結果

問題 No.950 行列累乗
ユーザー gew1fw
提出日時 2025-06-12 16:47:01
言語 PyPy3
(7.3.15)
結果
WA  
実行時間 -
コード長 4,241 bytes
コンパイル時間 175 ms
コンパイル使用メモリ 82,168 KB
実行使用メモリ 141,736 KB
最終ジャッジ日時 2025-06-12 16:47:57
合計ジャッジ時間 5,882 ms
ジャッジサーバーID
(参考情報)
judge5 / judge2
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 4
other AC * 30 WA * 27
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
from math import gcd
from random import randint

MOD = 10**18 + 3

def modinv(a, mod):
    g, x, y = extended_gcd(a, mod)
    if g != 1:
        return None
    else:
        return x % mod

def extended_gcd(a, b):
    if a == 0:
        return (b, 0, 1)
    else:
        g, y, x = extended_gcd(b % a, a)
        return (g, x - (b // a) * y, y)

def matrix_mult(m1, m2, mod):
    res = [[0]*len(m2[0]) for _ in range(len(m1))]
    for i in range(len(m1)):
        for k in range(len(m2)):
            if m1[i][k] == 0:
                continue
            for j in range(len(m2[0])):
                res[i][j] = (res[i][j] + m1[i][k] * m2[k][j]) % mod
    return res

def matrix_pow(mat, power, mod):
    result = [[1 if i == j else 0 for j in range(len(mat))] for i in range(len(mat))]
    while power > 0:
        if power % 2 == 1:
            result = matrix_mult(result, mat, mod)
        mat = matrix_mult(mat, mat, mod)
        power //= 2
    return result

def solve_discrete_log(a, b, mod):
    if b == 1:
        return 0
    if a == 0:
        if b == 0:
            return 1
        else:
            return -1
    if mod == 1:
        return 0
    if gcd(a, mod) != 1:
        if b % gcd(a, mod) != 0:
            return -1
        a //= gcd(a, mod)
        b //= gcd(a, mod)
        mod //= gcd(a, mod)
    max_d = int(mod ** 0.5) + 1
    value = {}
    current = 1
    for i in range(max_d):
        if current not in value:
            value[current] = i
        current = (current * a) % mod
    a_inv = modinv(a, mod)
    if a_inv is None:
        return -1
    gamma = pow(a_inv, max_d, mod)
    current = b % mod
    for i in range(max_d):
        if current in value:
            return i * max_d + value[current]
        current = (current * gamma) % mod
    return -1

def main():
    p = int(sys.stdin.readline())
    A = []
    for _ in range(2):
        A.append(list(map(int, sys.stdin.readline().split())))
    B = []
    for _ in range(2):
        B.append(list(map(int, sys.stdin.readline().split())))
    
    a = None
    valid = True
    A12 = A[0][1] % p
    B12 = B[0][1] % p
    A21 = A[1][0] % p
    B21 = B[1][0] % p
    if A12 != 0:
        a_A12 = B12 * modinv(A12, p) % p
        a = a_A12
    if A21 != 0:
        a_A21 = B21 * modinv(A21, p) % p
        if a is None:
            a = a_A21
        else:
            if a != a_A21:
                print(-1)
                return
    if a is None:
        a = 0
    else:
        if A12 != 0:
            if (a * A12) % p != B12:
                print(-1)
                return
        if A21 != 0:
            if (a * A21) % p != B21:
                print(-1)
                return
    b = (B[0][0] - a * A[0][0]) % p
    if (a * A[1][1] + b) % p != B[1][1]:
        print(-1)
        return
    
    t = (A[0][0] + A[1][1]) % p
    d = (A[0][0] * A[1][1] - A[0][1] * A[1][0]) % p

    if d == 0:
        if b != 0:
            print(-1)
            return
        current_a = 1
        current_b = 0
        if t == 0:
            if a != 0:
                print(-1)
                return
            else:
                print(1)
                return
        if t == 1:
            if a == current_a:
                print(1)
                return
            else:
                print(-1)
                return
        n_minus_1 = solve_discrete_log(t, a, p)
        if n_minus_1 == -1:
            print(-1)
            return
        else:
            print(n_minus_1 + 1)
            return
    else:
        M = [[t, -d], [1, 0]]
        a1 = 1
        a2 = t * a1 + 0
        a2_mod = a2 % p
        if a2 == a and 0 == b:
            print(2)
            return
        a_list = [a1, a2_mod]
        n = 2
        max_tries = 1000000
        found = False
        while max_tries > 0:
            n += 1
            a_next = (t * a_list[-1] - d * a_list[-2]) % p
            a_list.append(a_next)
            b_next = (-d * a_list[-2]) % p
            if a_next == a and b_next == b:
                found = True
                break
            max_tries -= 1
        if found:
            print(n)
            return
        else:
            print(-1)
            return

if __name__ == '__main__':
    main()
0