結果
| 問題 | No.590 Replacement | 
| コンテスト | |
| ユーザー |  lam6er | 
| 提出日時 | 2025-03-26 15:56:12 | 
| 言語 | PyPy3 (7.3.15) | 
| 結果 | 
                                TLE
                                 
                             | 
| 実行時間 | - | 
| コード長 | 2,574 bytes | 
| コンパイル時間 | 337 ms | 
| コンパイル使用メモリ | 82,152 KB | 
| 実行使用メモリ | 76,808 KB | 
| 最終ジャッジ日時 | 2025-03-26 15:56:44 | 
| 合計ジャッジ時間 | 5,472 ms | 
| ジャッジサーバーID (参考情報) | judge1 / judge5 | 
(要ログイン)
| ファイルパターン | 結果 | 
|---|---|
| other | AC * 7 TLE * 1 -- * 39 | 
ソースコード
import sys
from math import gcd
def extended_gcd(a, b):
    if a == 0:
        return (b, 0, 1)
    else:
        g, y, x = extended_gcd(b % a, a)
        return (g, x - (b // a) * y, y)
def crt(a1, m1, a2, m2):
    g, x, y = extended_gcd(m1, m2)
    if (a2 - a1) % g != 0:
        return None
    lcm = m1 // g * m2
    x0 = (a1 + (x * (a2 - a1) // g) % (m2 // g) * m1) % lcm
    return x0
def find_cycles(perm):
    n = len(perm)
    visited = [False] * (n + 1)
    cycles = []
    for i in range(1, n + 1):
        if not visited[i]:
            cycle = []
            j = i
            while not visited[j]:
                visited[j] = True
                cycle.append(j)
                j = perm[j - 1]
            cycles.append(cycle)
    return cycles
def main():
    input = sys.stdin.read().split()
    ptr = 0
    N = int(input[ptr])
    ptr +=1
    A = list(map(int, input[ptr:ptr+N]))
    ptr +=N
    B = list(map(int, input[ptr:ptr+N]))
    ptr +=N
    
    cycles_A = find_cycles(A)
    cycles_B = find_cycles(B)
    
    total = 0
    
    for cycleA in cycles_A:
        la = len(cycleA)
        posA = {num: idx for idx, num in enumerate(cycleA)}
        setA = set(cycleA)
        for cycleB in cycles_B:
            lb = len(cycleB)
            posB = {num: idx for idx, num in enumerate(cycleB)}
            setB = set(cycleB)
            S = setA & setB
            if not S:
                continue
            k = len(S)
            t1_z = {}
            for z in S:
                t1_z_z = {}
                for x in cycleA:
                    t1 = (cycleA.index(z) - cycleA.index(x)) % la
                    if x not in t1_z:
                        t1_z[x] = {}
                    t1_z[x][z] = t1
            t2_z = {}
            for z in S:
                t2_z_z = {}
                for y in cycleB:
                    t2 = (cycleB.index(z) - cycleB.index(y)) % lb
                    if y not in t2_z:
                        t2_z[y] = {}
                    t2_z[y][z] = t2
            
            for x in cycleA:
                for y in cycleB:
                    min_t = None
                    for z in S:
                        t1 = t1_z[x][z]
                        t2 = t2_z[y][z]
                        res = crt(t1, la, t2, lb)
                        if res is not None:
                            if min_t is None or res < min_t:
                                min_t = res
                    if min_t is not None:
                        total += min_t
    print(total)
if __name__ == '__main__':
    main()
            
            
            
        