結果

問題 No.950 行列累乗
コンテスト
ユーザー gew1fw
提出日時 2025-06-12 18:41:04
言語 PyPy3
(7.3.15)
結果
MLE  
実行時間 -
コード長 6,589 bytes
コンパイル時間 193 ms
コンパイル使用メモリ 82,280 KB
実行使用メモリ 513,720 KB
最終ジャッジ日時 2025-06-12 18:41:15
合計ジャッジ時間 6,801 ms
ジャッジサーバーID
(参考情報)
judge3 / judge1
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample -- * 4
other MLE * 1 -- * 56
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
import math
from collections import defaultdict

# Fixed 10^9 + 7 modulus. NOTE(review): the problem's modulus is the p read
# from input, so arithmetic modulo p should not rely on this constant.
MOD = 1_000_000_007

def readints():
    """Read one line from standard input and return its whitespace-separated integers."""
    line = sys.stdin.readline()
    return [int(token) for token in line.split()]

def matrix_mult(a, b, p):
    """Return the 2x2 matrix product a*b as a new list of lists, entries reduced mod p.

    Only the top-left 2x2 part of each argument is read.
    """
    return [
        [sum(a[row][t] * b[t][col] for t in range(2)) % p for col in range(2)]
        for row in range(2)
    ]

def matrix_pow(mat, power, p):
    """Raise the 2x2 matrix mat to an integer power, reducing entries modulo p.

    Square-and-multiply over the binary digits of power. For power <= 0 the
    plain (unreduced) 2x2 identity [[1, 0], [0, 1]] is returned.
    """
    def _times(x, y):
        # 2x2 product with every entry reduced mod p.
        return [
            [(x[r][0] * y[0][c] + x[r][1] * y[1][c]) % p for c in range(2)]
            for r in range(2)
        ]

    acc = [[1, 0], [0, 1]]
    base = mat
    remaining = power
    while remaining > 0:
        remaining, low_bit = divmod(remaining, 2)
        if low_bit == 1:
            acc = _times(acc, base)
        base = _times(base, base)
    return acc

def modinv(a, p):
    """Return the inverse of a modulo p, or None if extended_gcd(a, p) reports a gcd other than 1.

    The Bezout coefficient of a is reduced into the canonical range of p.
    """
    gcd_value, coeff, _unused = extended_gcd(a, p)
    return coeff % p if gcd_value == 1 else None

def extended_gcd(a, b):
    """Return a tuple (g, x, y) with a*x + b*y == g, via recursive Euclid.

    For non-negative a and b, g is gcd(a, b). The base case a == 0 yields (b, 0, 1).
    """
    if a != 0:
        g, s, t = extended_gcd(b % a, a)
        quotient = b // a
        return (g, t - quotient * s, s)
    return (b, 0, 1)

def _mat_mul(x, y, p):
    """Return the 2x2 product x*y over Z/pZ.

    Matrices are flat row-major 4-tuples (m11, m12, m21, m22) so they are
    hashable and can key the baby-step tables used below.
    """
    a, b, c, d = x
    e, f, g, h = y
    return (
        (a * e + b * g) % p,
        (a * f + b * h) % p,
        (c * e + d * g) % p,
        (c * f + d * h) % p,
    )


def _mat_pow(x, n, p):
    """Return x**n over Z/pZ for a flat 2x2 matrix x and an integer n >= 0."""
    result = (1, 0, 0, 1)
    while n:
        if n & 1:
            result = _mat_mul(result, x, p)
        x = _mat_mul(x, x, p)
        n >>= 1
    return result


def _normalize(m, p):
    """Scale flat matrix m so its first nonzero entry is 1; the zero matrix is returned as is.

    Two nonzero matrices are scalar multiples of each other exactly when
    their normalized forms are equal. Uses Fermat inversion, so p must be prime.
    """
    for v in m:
        if v:
            inv = pow(v, p - 2, p)
            return tuple(w * inv % p for w in m)
    return m


def _proj_log(A, A_inv, target, p, bound):
    """Return the least r >= 0 with A**r a nonzero scalar multiple of target, or None.

    Baby-step giant-step on matrices compared up to a scalar. Every exponent
    r < m*m is examined, where m*m > bound. So bound must be at least the
    projective order of A (the least k >= 1 with A**k scalar).
    """
    m = math.isqrt(bound) + 1
    baby = {}
    cur = (1, 0, 0, 1)
    for j in range(m):
        # setdefault keeps the smallest exponent per class, so the first hit
        # in the giant-step loop below is the minimal r.
        baby.setdefault(_normalize(cur, p), j)
        cur = _mat_mul(cur, A, p)
    giant = _mat_pow(A_inv, m, p)
    # Powers of A commute: target ~ A**(i*m + j)  <=>  target * A**(-i*m) ~ A**j.
    cur = target
    for i in range(m):
        j = baby.get(_normalize(cur, p))
        if j is not None:
            return i * m + j
        cur = _mat_mul(cur, giant, p)
    return None


def _scalar_log(g, t, p, bound):
    """Return the least e >= 0 with g**e == t (mod p), or None.

    Baby-step giant-step in (Z/pZ)^*, covering every e < m*m (m*m > bound).
    Requires p prime and g != 0 mod p.
    """
    m = math.isqrt(bound) + 1
    baby = {}
    cur = 1
    for j in range(m):
        baby.setdefault(cur, j)
        cur = cur * g % p
    giant = pow(g, (p - 2) * m, p)  # g**(-m) by Fermat's little theorem
    cur = t % p
    for i in range(m):
        j = baby.get(cur)
        if j is not None:
            return i * m + j
        cur = cur * giant % p
    return None


def _solve_singular(p, A, B, tr):
    """Least n >= 1 with A**n == B for a nonzero A with det(A) == 0, else -1.

    Cayley-Hamilton with det == 0 gives A**2 == tr*A, hence A**n == tr**(n-1) * A.
    """
    if B == A:
        return 1
    if tr == 0:
        # A**2 == 0, so every power from n = 2 on is the zero matrix.
        return 2 if B == (0, 0, 0, 0) else -1
    pivot = next(i for i in range(4) if A[i])
    s = B[pivot] * pow(A[pivot], p - 2, p) % p
    # B must equal s*A with s = tr**(n-1), which is never 0.
    if s == 0 or any(B[i] != s * A[i] % p for i in range(4)):
        return -1
    e = _scalar_log(tr, s, p, p)
    return -1 if e is None else e + 1


def _solve_invertible(p, A, B, det):
    """Least n >= 1 with A**n == B for an invertible A, else -1.

    Let k be the projective order of A, with A**k == c*I. In PGL(2, p) every
    element order divides p - 1, p or p + 1, so k <= p + 1 and the matrix
    BSGS needs only O(sqrt(p)) table entries.
    Any solution has n = r0 + q*k, where r0 < k is the least exponent with
    A**r0 ~ B (say B == t * A**r0) and q >= 0 satisfies c**q == t.
    """
    a11, a12, a21, a22 = A
    inv_det = pow(det, p - 2, p)
    A_inv = (a22 * inv_det % p, -a12 * inv_det % p,
             -a21 * inv_det % p, a11 * inv_det % p)
    bound = p + 1
    # A**k ~ I with k >= 1  <=>  A**(k-1) ~ A**-1 with k - 1 >= 0; always found within bound.
    k = _proj_log(A, A_inv, A_inv, p, bound) + 1
    c = _mat_pow(A, k, p)[0]
    r0 = _proj_log(A, A_inv, B, p, bound)
    if r0 is None:
        return -1
    base = _mat_pow(A, r0, p)
    pivot = next(i for i in range(4) if base[i])
    t = B[pivot] * pow(base[pivot], p - 2, p) % p
    q0 = _scalar_log(c, t, p, p)
    if q0 is None:
        return -1
    n = r0 + q0 * k
    if n == 0:
        # B == I: n must be positive, so step to the order of A, k * ord(c).
        ord_c = _scalar_log(c, pow(c, p - 2, p), p, p) + 1
        n = k * ord_c
    return n


def _solve(p, A, B):
    """Return the least n >= 1 with A**n == B over Z/pZ, or -1 if none exists.

    A and B are flat row-major 4-tuples already reduced mod p. Assumes p is
    prime (as the problem states; NOTE(review): confirm against the statement).
    """
    zero = (0, 0, 0, 0)
    if A == zero:
        return 1 if B == zero else -1
    a11, a12, a21, a22 = A
    det = (a11 * a22 - a12 * a21) % p
    tr = (a11 + a22) % p
    if det == 0:
        return _solve_singular(p, A, B, tr)
    return _solve_invertible(p, A, B, det)


def main(data=None):
    """Print (and return) the least n >= 1 with A**n == B (mod p), or -1.

    Input is p, then the rows of A and of B (two integers per row, 8 after p).
    The input is taken from *data* when given, otherwise from all of standard input.
    Enumerating powers of A can take ~p**2 steps and entries. Instead, every
    search here is a baby-step giant-step table of O(sqrt(p)) entries.
    """
    if data is None:
        data = sys.stdin.read()
    tokens = data.split()
    p = int(tokens[0])
    A = tuple(int(v) % p for v in tokens[1:5])
    B = tuple(int(v) % p for v in tokens[5:9])
    answer = _solve(p, A, B)
    print(answer)
    return answer

if __name__ == '__main__':
    # Run the solver only when executed as a script, not when imported.
    main()
# NOTE(review): the bare expression below has no effect; it looks like residue from the scraped page.
0