結果

問題 No.950 行列累乗
ユーザー lam6er
提出日時 2025-03-31 17:32:21
言語 PyPy3
(7.3.15)
結果
WA  
実行時間 -
コード長 5,930 bytes
コンパイル時間 137 ms
コンパイル使用メモリ 82,232 KB
実行使用メモリ 82,428 KB
最終ジャッジ日時 2025-03-31 17:33:18
合計ジャッジ時間 5,155 ms
ジャッジサーバーID
(参考情報)
judge5 / judge4
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 4
other AC * 40 WA * 17
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
import math

def input():
    return sys.stdin.read()

def matrices_equal(a, b, mod):
    for i in range(2):
        for j in range(2):
            if (a[i][j] % mod) != (b[i][j] % mod):
                return False
    return True

def matrix_mult(a, b, mod):
    return [
        [
            (a[0][0]*b[0][0] + a[0][1]*b[1][0]) % mod,
            (a[0][0]*b[0][1] + a[0][1]*b[1][1]) % mod
        ],
        [
            (a[1][0]*b[0][0] + a[1][1]*b[1][0]) % mod,
            (a[1][0]*b[0][1] + a[1][1]*b[1][1]) % mod
        ]
    ]

def matrix_power(mat, power, mod):
    result = [[1,0], [0,1]]
    while power > 0:
        if power % 2 == 1:
            result = matrix_mult(result, mat, mod)
        mat = matrix_mult(mat, mat, mod)
        power //= 2
    return result

def discrete_log(a, b, mod):
    if b == 1:
        return 0
    a %= mod
    b %= mod
    n = int(math.isqrt(mod)) + 1
    table = dict()
    current = 1
    for i in range(n):
        if current not in table:
            table[current] = i
        current = current * a % mod
    an = pow(a, n, mod)
    ainverse = pow(an, mod-2, mod)
    current = b
    for i in range(n):
        if current in table:
            return i * n + table[current]
        current = current * ainverse % mod
    return -1

def find_order(a, mod):
    if a == 0:
        return -1
    a %= mod
    if math.gcd(a, mod) != 1:
        return -1
    m = mod - 1
    factors = {}
    temp = m
    for i in [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37]:
        if temp == 1:
            break
        while temp % i == 0:
            factors[i] = factors.get(i, 0) + 1
            temp //= i
    if temp != 1:
        factors[temp] = 1
    divisors = [1]
    for p, exp in factors.items():
        current_length = len(divisors)
        for e in range(1, exp + 1):
            p_pow = p ** e
            for d in divisors[:current_length]:
                divisors.append(d * p_pow)
    divisors = sorted(list(set(divisors)))
    for d in divisors:
        if pow(a, d, mod) == 1:
            return d
    return m

def main():
    data = input().split()
    idx = 0
    p = int(data[idx])
    idx +=1
    A = []
    for _ in range(2):
        a11 = int(data[idx])
        a12 = int(data[idx+1])
        A.append([a11, a12])
        idx +=2
    B = []
    for _ in range(2):
        b11 = int(data[idx])
        b12 = int(data[idx+1])
        B.append([b11, b12])
        idx +=2

    det_A = (A[0][0] * A[1][1] - A[0][1] * A[1][0]) % p
    det_B = (B[0][0] * B[1][1] - B[0][1] * B[1][0]) % p

    if det_A == 0:
        if det_B != 0:
            print(-1)
            return
        is_zero_A = all(A[i][j] % p == 0 for i in range(2) for j in range(2))
        if is_zero_A:
            is_zero_B = all(B[i][j] % p == 0 for i in range(2) for j in range(2))
            if is_zero_B:
                print(1)
            else:
                print(-1)
            return
        else:
            tr_A = (A[0][0] + A[1][1]) % p
            if tr_A == 0:
                if matrices_equal(A, B, p):
                    print(1)
                else:
                    is_zero_B = all(B[i][j] % p == 0 for i in range(2) for j in range(2))
                    if is_zero_B:
                        print(2)
                    else:
                        print(-1)
                return
            else:
                k = None
                valid = True
                for i in range(2):
                    for j in range(2):
                        a_ij = A[i][j] % p
                        b_ij = B[i][j] % p
                        if a_ij == 0:
                            if b_ij != 0:
                                valid = False
                        else:
                            inv = pow(a_ij, p-2, p)
                            curr_k = (b_ij * inv) % p
                            if k is None:
                                k = curr_k
                            else:
                                if curr_k != k:
                                    valid = False
                if not valid:
                    print(-1)
                    return
                if k is None:
                    is_zero_B = all(B[i][j] % p == 0 for i in range(2) for j in range(2))
                    if is_zero_B:
                        k = 0
                    else:
                        print(-1)
                        return
                if k ==0:
                    if all(B[i][j] % p ==0 for i in range(2) for j in range(2)):
                        k=0
                    else:
                        print(-1)
                        return
                x = discrete_log(tr_A, k, p)
                if x == -1:
                    print(-1)
                else:
                    n = x +1
                    print(n if n >0 else -1)
                return
    else:
        if det_B ==0:
            print(-1)
            return
        d = det_A
        e = det_B
        x = discrete_log(d, e, p)
        if x == -1:
            print(-1)
            return
        order = find_order(d, p)
        candidates = []
        current_n = x
        max_candidates = 1000
        k =0
        while len(candidates) < max_candidates:
            n_candidate = current_n + k * order
            if n_candidate >0:
                candidates.append(n_candidate)
            if len(candidates) >= 1000:
                break
            k +=1
        min_n = None
        for candidate in candidates:
            if candidate <=0:
                continue
            A_pow = matrix_power(A, candidate, p)
            if matrices_equal(A_pow, B, p):
                if min_n is None or candidate < min_n:
                    min_n = candidate
                    break
        if min_n is not None:
            print(min_n)
        else:
            print(-1)
        return

if __name__ == "__main__":
    main()
0