結果

問題 No.590 Replacement
ユーザー lam6er
提出日時 2025-03-26 15:56:12
言語 PyPy3
(7.3.15)
結果
TLE  
実行時間 -
コード長 2,574 bytes
コンパイル時間 337 ms
コンパイル使用メモリ 82,152 KB
実行使用メモリ 76,808 KB
最終ジャッジ日時 2025-03-26 15:56:44
合計ジャッジ時間 5,472 ms
ジャッジサーバーID
(参考情報)
judge1 / judge5
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
other AC * 7 TLE * 1 -- * 39
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
from math import gcd

def extended_gcd(a, b):
    if a == 0:
        return (b, 0, 1)
    else:
        g, y, x = extended_gcd(b % a, a)
        return (g, x - (b // a) * y, y)

def crt(a1, m1, a2, m2):
    g, x, y = extended_gcd(m1, m2)
    if (a2 - a1) % g != 0:
        return None
    lcm = m1 // g * m2
    x0 = (a1 + (x * (a2 - a1) // g) % (m2 // g) * m1) % lcm
    return x0

def find_cycles(perm):
    n = len(perm)
    visited = [False] * (n + 1)
    cycles = []
    for i in range(1, n + 1):
        if not visited[i]:
            cycle = []
            j = i
            while not visited[j]:
                visited[j] = True
                cycle.append(j)
                j = perm[j - 1]
            cycles.append(cycle)
    return cycles

def main():
    input = sys.stdin.read().split()
    ptr = 0
    N = int(input[ptr])
    ptr +=1
    A = list(map(int, input[ptr:ptr+N]))
    ptr +=N
    B = list(map(int, input[ptr:ptr+N]))
    ptr +=N
    
    cycles_A = find_cycles(A)
    cycles_B = find_cycles(B)
    
    total = 0
    
    for cycleA in cycles_A:
        la = len(cycleA)
        posA = {num: idx for idx, num in enumerate(cycleA)}
        setA = set(cycleA)
        for cycleB in cycles_B:
            lb = len(cycleB)
            posB = {num: idx for idx, num in enumerate(cycleB)}
            setB = set(cycleB)
            S = setA & setB
            if not S:
                continue
            k = len(S)
            t1_z = {}
            for z in S:
                t1_z_z = {}
                for x in cycleA:
                    t1 = (cycleA.index(z) - cycleA.index(x)) % la
                    if x not in t1_z:
                        t1_z[x] = {}
                    t1_z[x][z] = t1
            t2_z = {}
            for z in S:
                t2_z_z = {}
                for y in cycleB:
                    t2 = (cycleB.index(z) - cycleB.index(y)) % lb
                    if y not in t2_z:
                        t2_z[y] = {}
                    t2_z[y][z] = t2
            
            for x in cycleA:
                for y in cycleB:
                    min_t = None
                    for z in S:
                        t1 = t1_z[x][z]
                        t2 = t2_z[y][z]
                        res = crt(t1, la, t2, lb)
                        if res is not None:
                            if min_t is None or res < min_t:
                                min_t = res
                    if min_t is not None:
                        total += min_t
    print(total)

if __name__ == '__main__':
    main()
0