結果

問題 No.950 行列累乗
ユーザー gew1fw
提出日時 2025-06-12 18:27:57
言語 PyPy3
(7.3.15)
結果
WA  
実行時間 -
コード長 8,065 bytes
コンパイル時間 208 ms
コンパイル使用メモリ 82,112 KB
実行使用メモリ 83,508 KB
最終ジャッジ日時 2025-06-12 18:28:04
合計ジャッジ時間 6,102 ms
ジャッジサーバーID
(参考情報)
judge4 / judge2
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 4
other AC * 37 WA * 20
権限があれば一括ダウンロードができます

ソースコード

diff #

import math
import random

def multiply(a, b, p):
    res = [[0]*2 for _ in range(2)]
    res[0][0] = (a[0][0] * b[0][0] + a[0][1] * b[1][0]) % p
    res[0][1] = (a[0][0] * b[0][1] + a[0][1] * b[1][1]) % p
    res[1][0] = (a[1][0] * b[0][0] + a[1][1] * b[1][0]) % p
    res[1][1] = (a[1][0] * b[0][1] + a[1][1] * b[1][1]) % p
    return res

def matrix_pow(mat, power, p):
    result = [[1, 0], [0, 1]]  # Identity matrix
    while power > 0:
        if power % 2 == 1:
            result = multiply(result, mat, p)
        mat = multiply(mat, mat, p)
        power //= 2
    return result

def is_prime(n):
    if n < 2:
        return False
    for p in [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37]:
        if n % p == 0:
            return n == p
    d = n - 1
    s = 0
    while d % 2 == 0:
        d //= 2
        s += 1
    for a in [2, 3, 5, 7, 11]:
        if a >= n:
            continue
        x = pow(a, d, n)
        if x == 1 or x == n - 1:
            continue
        for _ in range(s - 1):
            x = pow(x, 2, n)
            if x == n - 1:
                break
        else:
            return False
    return True

def pollards_rho(n):
    if n % 2 == 0:
        return 2
    if n % 3 == 0:
        return 3
    if n % 5 == 0:
        return 5
    while True:
        c = random.randint(1, n - 1)
        f = lambda x: (pow(x, 2, n) + c) % n
        x, y, d = 2, 2, 1
        while d == 1:
            x = f(x)
            y = f(f(y))
            d = math.gcd((x - y) % n, n)
        if d != n:
            return d

def factor(n):
    factors = []
    def _factor(n):
        if n == 1:
            return
        if is_prime(n):
            factors.append(n)
            return
        d = pollards_rho(n)
        _factor(d)
        _factor(n // d)
    _factor(n)
    factors.sort()
    return factors

def prime_factors(n):
    factors = factor(n)
    res = {}
    for p in factors:
        res[p] = res.get(p, 0) + 1
    return res

def compute_order(a, p):
    if a == 0 or p <= 1 or math.gcd(a, p) != 1:
        return -1
    m = p - 1
    factors = prime_factors(m)
    unique_factors = list(factors.keys())
    for factor_p in unique_factors:
        exponent = factors[factor_p]
        for e in range(exponent, 0, -1):
            if pow(a, m // (factor_p ** e), p) == 1:
                m = m // (factor_p ** e)
                break
    return m

def baby_step_giant_step(a, b, p):
    if b == 1:
        return 0
    if a == 0:
        return -1
    a = a % p
    b = b % p
    n = int(math.isqrt(p)) + 1
    table = {}
    current = 1
    for i in range(n):
        if current not in table:
            table[current] = i
        current = (current * a) % p
    an = pow(a, n, p)
    an = pow(an, p-2, p)
    current = b
    for i in range(n):
        if current in table:
            return i * n + table[current]
        current = (current * an) % p
    return -1

def discrete_log(a, b, p):
    if p == 1:
        return 0 if b == 1 else -1
    if a == 0:
        if b == 0:
            return 1
        else:
            return -1
    a = a % p
    b = b % p
    if b == 1:
        return 0
    if a == 0:
        return -1
    return baby_step_giant_step(a, b, p)

def main():
    p = int(input())
    A = []
    for _ in range(2):
        A.append(list(map(int, input().split())))
    B = []
    for _ in range(2):
        B.append(list(map(int, input().split())))
    
    # Compute determinants
    det_A = (A[0][0] * A[1][1] - A[0][1] * A[1][0]) % p
    det_B = (B[0][0] * B[1][1] - B[0][1] * B[1][0]) % p
    
    if det_A != 0:
        # Check if det_B is possible
        log_val = discrete_log(det_A, det_B, p)
        if log_val == -1:
            print(-1)
            return
        # Check if A^log_val == B
        a_pow = matrix_pow(A, log_val, p)
        if a_pow == B:
            print(log_val)
            return
        # Find the order of det_A
        ord_det = compute_order(det_A, p)
        if ord_det == -1:
            print(-1)
            return
        current_x = log_val
        min_x = None
        for _ in range(100):
            current_x += ord_det
            a_pow = matrix_pow(A, current_x, p)
            if a_pow == B:
                if min_x is None or current_x < min_x:
                    min_x = current_x
        if min_x is not None:
            print(min_x)
            return
        else:
            print(-1)
            return
    else:
        if det_B != 0:
            print(-1)
            return
        # Check if A^2 is zero matrix
        A_sq = multiply(A, A, p)
        is_zero = True
        for i in range(2):
            for j in range(2):
                if A_sq[i][j] != 0:
                    is_zero = False
        if is_zero:
            # Check if B is zero matrix or A
            is_B_zero = True
            for i in range(2):
                for j in range(2):
                    if B[i][j] != 0:
                        is_B_zero = False
            if is_B_zero:
                print(2)
                return
            if A == B:
                print(1)
                return
            else:
                print(-1)
                return
        # Check if A^2 = c*A
        c = None
        for i in range(2):
            for j in range(2):
                if A[i][j] == 0:
                    if A_sq[i][j] != 0:
                        c = None
                        break
                else:
                    inv = pow(A[i][j], p-2, p)
                    current_c = (A_sq[i][j] * inv) % p
                    if c is None:
                        c = current_c
                    else:
                        if current_c != c:
                            c = None
                            break
            if c is None:
                break
        if c is not None:
            # Check if B is k*A
            k = None
            for i in range(2):
                for j in range(2):
                    if A[i][j] == 0:
                        if B[i][j] != 0:
                            k = None
                            break
                    else:
                        inv = pow(A[i][j], p-2, p)
                        current_k = (B[i][j] * inv) % p
                        if k is None:
                            k = current_k
                        else:
                            if current_k != k:
                                k = None
                                break
                if k is None:
                    break
            if k is not None:
                if c == 0:
                    if k == 0:
                        # Find minimal n >=2
                        a_pow = matrix_pow(A, 2, p)
                        if a_pow == B:
                            print(2)
                            return
                        else:
                            print(-1)
                            return
                    else:
                        print(-1)
                        return
                else:
                    log_val = discrete_log(c, k, p)
                    if log_val == -1:
                        print(-1)
                        return
                    n_candidate = log_val + 1
                    if n_candidate < 1:
                        print(-1)
                        return
                    a_pow = matrix_pow(A, n_candidate, p)
                    if a_pow == B:
                        print(n_candidate)
                        return
                    else:
                        print(-1)
                        return
            else:
                print(-1)
                return
        else:
            # Check small n
            current = A
            if current == B:
                print(1)
                return
            for n in range(2, 10**5+1):
                current = multiply(current, A, p)
                if current == B:
                    print(n)
                    return
            print(-1)
            return

if __name__ == "__main__":
    main()
0