結果

問題 No.950 行列累乗
ユーザー gew1fw
提出日時 2025-06-12 18:28:52
言語 PyPy3
(7.3.15)
結果
WA  
実行時間 -
コード長 3,020 bytes
コンパイル時間 255 ms
コンパイル使用メモリ 82,356 KB
実行使用メモリ 55,556 KB
最終ジャッジ日時 2025-06-12 18:29:36
合計ジャッジ時間 4,316 ms
ジャッジサーバーID
(参考情報)
judge3 / judge4
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 4
other AC * 23 WA * 34
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
import math

def input():
    return sys.stdin.read()

def mat_mult(a, b, p):
    return [
        [(a[0][0]*b[0][0] + a[0][1]*b[1][0]) % p,
         (a[0][0]*b[0][1] + a[0][1]*b[1][1]) % p],
        [(a[1][0]*b[0][0] + a[1][1]*b[1][0]) % p,
         (a[1][0]*b[0][1] + a[1][1]*b[1][1]) % p]
    ]

def mat_pow(mat, power, p):
    result = [[1,0], [0,1]]
    while power > 0:
        if power % 2 == 1:
            result = mat_mult(result, mat, p)
        mat = mat_mult(mat, mat, p)
        power //= 2
    return result

def mat_eq(a, b):
    return a[0] == b[0] and a[1] == b[1]

def det(mat, p):
    return (mat[0][0] * mat[1][1] - mat[0][1] * mat[1][0]) % p

def solve():
    data = input().split()
    idx = 0
    p = int(data[idx]); idx +=1
    A = [
        [int(data[idx])%p, int(data[idx+1])%p],
        [int(data[idx+2])%p, int(data[idx+3])%p]
    ]
    idx +=4
    B = [
        [int(data[idx])%p, int(data[idx+1])%p],
        [int(data[idx+2])%p, int(data[idx+3])%p]
    ]
    idx +=4

    detA = det(A, p)
    detB = det(B, p)

    if detA % p == 0:
        if detB % p != 0:
            print(-1)
            return
        current = A
        n = 1
        if mat_eq(current, B):
            print(1)
            return
        for n in range(2, 100):
            current = mat_mult(current, A, p)
            if mat_eq(current, B):
                print(n)
                return
            if all(cell == 0 for row in current for cell in row):
                break
        print(-1)
        return

    t = (A[0][0] + A[1][1]) % p
    d = detA % p

    A12 = A[0][1]
    A21 = A[1][0]
    B12 = B[0][1]
    B21 = B[1][0]

    if A12 == 0 and B12 != 0:
        print(-1)
        return
    if A21 == 0 and B21 != 0:
        print(-1)
        return

    c_n = None
    if A12 != 0:
        inv_A12 = pow(A12, p-2, p)
        c_n = (B12 * inv_A12) % p
        if A21 != 0:
            inv_A21 = pow(A21, p-2, p)
            c_n2 = (B21 * inv_A21) % p
            if c_n != c_n2:
                print(-1)
                return
    else:
        if A21 != 0:
            inv_A21 = pow(A21, p-2, p)
            c_n = (B21 * inv_A21) % p
        else:
            c_n = 0

    s_n1 = (B[0][0] - c_n * A[0][0]) % p
    s_n2 = (B[1][1] - c_n * A[1][1]) % p
    if s_n1 != s_n2:
        print(-1)
        return
    s_n = s_n1

    a1 = 1
    s1 = 0
    a2 = t % p
    s2 = (-d) % p

    if a1 == c_n and s1 == s_n:
        print(1)
        return
    if a2 == c_n and s2 == s_n:
        print(2)
        return

    current_a = a2
    current_s = s2
    prev_a = a1
    prev_s = s1
    n = 3
    found = False
    for _ in range(1000):
        new_a = (t * current_a - d * prev_a) % p
        new_s = (t * current_s - d * prev_s) % p
        if new_a == c_n and new_s == s_n:
            found = True
            break
        prev_a, current_a = current_a, new_a
        prev_s, current_s = current_s, new_s
        n += 1

    if found:
        print(n)
        return

    print(-1)

solve()
0