結果

問題 No.950 行列累乗
ユーザー lam6er
提出日時 2025-03-20 18:57:46
言語 PyPy3
(7.3.15)
結果
WA  
実行時間 -
コード長 6,349 bytes
コンパイル時間 215 ms
コンパイル使用メモリ 82,704 KB
実行使用メモリ 77,960 KB
最終ジャッジ日時 2025-03-20 18:58:49
合計ジャッジ時間 4,961 ms
ジャッジサーバーID
(参考情報)
judge5 / judge4
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 4
other AC * 40 WA * 17
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
import math
from math import sqrt, gcd

def input():
    return sys.stdin.read()

def mat_mult(a, b, p):
    return [
        [(a[0][0]*b[0][0] + a[0][1]*b[1][0]) % p,
         (a[0][0]*b[0][1] + a[0][1]*b[1][1]) % p],
        [(a[1][0]*b[0][0] + a[1][1]*b[1][0]) % p,
         (a[1][0]*b[0][1] + a[1][1]*b[1][1]) % p]
    ]

def mat_pow(mat, power, p):
    result = [[1,0], [0,1]]  # Identity matrix
    while power > 0:
        if power % 2 == 1:
            result = mat_mult(result, mat, p)
        mat = mat_mult(mat, mat, p)
        power //= 2
    return result

def baby_step_giant_step(a, b, p):
    if b == 1:
        return 0
    if a == 0:
        return -1
    m = int(math.isqrt(p)) + 1
    table = {}
    current = 1
    for j in range(m):
        if current not in table:
            table[current] = j
        current = (current * a) % p
    am_inv = pow(a, m * (p-2), p)
    gamma = b
    for i in range(m):
        if gamma in table:
            return i * m + table[gamma]
        gamma = (gamma * am_inv) % p
    return -1

def discrete_log(g, h, p):
    if g == 0:
        if h == 0:
            return 1
        else:
            return -1
    if h == 1:
        return 0
    if g == h:
        return 1
    return baby_step_giant_step(g, h, p)

def order(g, p):
    if g == 0:
        return -1
    if g % p == 0:
        return -1
    phi = p - 1
    factors = {}
    temp = phi
    for i in [2, 3, 5, 7]:
        while temp % i == 0:
            factors[i] = factors.get(i, 0) + 1
            temp //= i
    i = 11
    while i*i <= temp:
        while temp % i == 0:
            factors[i] = factors.get(i, 0) + 1
            temp //= i
        i += 2
    if temp > 1:
        factors[temp] = 1
    order = phi
    for prime in factors:
        test_order = order // prime
        if pow(g, test_order, p) == 1:
            order = test_order
            while order % prime == 0:
                candidate = order // prime
                if pow(g, candidate, p) != 1:
                    break
                order = candidate
    return order

def solve_case():
    data = input().split()
    ptr = 0
    p = int(data[ptr]); ptr +=1
    A = [[int(data[ptr])%p, int(data[ptr+1])%p], [int(data[ptr+2])%p, int(data[ptr+3])%p]]
    ptr +=4
    B = [[int(data[ptr])%p, int(data[ptr+1])%p], [int(data[ptr+2])%p, int(data[ptr+3])%p]]
    ptr +=4
    
    det_a = (A[0][0] * A[1][1] - A[0][1] * A[1][0]) % p
    det_b = (B[0][0] * B[1][1] - B[0][1] * B[1][0]) % p
    
    if det_a == 0:
        if det_b != 0:
            print(-1)
            return
        zero_A = all(A[i][j] == 0 for i in range(2) for j in range(2))
        if zero_A:
            zero_B = all(B[i][j] == 0 for i in range(2) for j in range(2))
            if zero_B:
                print(1)
                return
            else:
                print(-1)
                return
        non_zero_rows = []
        for i in range(2):
            for j in range(2):
                if A[i][j] != 0:
                    non_zero_rows.append((i, j))
        if not non_zero_rows:
            print(-1)
            return
        k_values = []
        for (i,j) in non_zero_rows:
            a_ij = A[i][j]
            b_ij = B[i][j]
            inv_a = pow(a_ij, p-2, p)
            k = (b_ij * inv_a) % p
            k_values.append(k)
        k = k_values[0]
        if not all(kv == k for kv in k_values):
            print(-1)
            return
        for i in range(2):
            for j in range(2):
                a = A[i][j]
                b = B[i][j]
                if (a * k) % p != b % p:
                    print(-1)
                    return
        t = (A[0][0] + A[1][1]) % p
        if t == 0:
            if k == 0:
                print(2)
                return
            elif k == 1:
                print(1)
                return
            else:
                print(-1)
                return
        x = discrete_log(t, k, p)
        if x == -1:
            print(-1)
            return
        else:
            n = x + 1
            if n <=0:
                print(-1)
                return
            current_k = pow(t, x, p)
            if current_k != k:
                print(-1)
                return
            print(n)
            return
    else:
        if det_b ==0:
            print(-1)
            return
        log_val = discrete_log(det_a, det_b, p)
        if log_val == -1:
            print(-1)
            return
        
        ord_deta = order(det_a, p)
        if ord_deta == -1:
            print(-1)
            return
        
        t = (A[0][0] + A[1][1]) % p
        d = det_a
        
        def compute_an_bn(n, p, t, d):
            if n == 1:
                a_n = 1 % p
                b_n = 0 % p
                return (a_n, b_n)
            if n == 2:
                a_n = t % p
                b_n = (-d) % p
                return (a_n, b_n)
            trans_mat = [[t % p, (-d) % p], [1, 0]]
            power = n - 2
            mat_power = mat_pow(trans_mat, power, p)
            a_2 = t % p
            a_1 = 1 % p
            a_n = (mat_power[0][0] * a_2 + mat_power[0][1] * a_1) % p
            a_n_minus_1 = (mat_power[1][0] * a_2 + mat_power[1][1] * a_1) % p
            b_n = (-d * a_n_minus_1) % p
            return (a_n, b_n)
        
        candidates = []
        if log_val <=0:
            log_val += ord_deta
        while log_val <=0:
            log_val += ord_deta
        
        current_n = log_val
        max_candidates = 100
        for i in range(max_candidates):
            candidates.append(current_n)
            current_n += ord_deta
            if current_n > 1e18:
                break
        
        min_n = -1
        for candidate in candidates:
            a_n, b_n = compute_an_bn(candidate, p, t, d)
            a11 = (a_n * A[0][0] + b_n) % p
            a12 = (a_n * A[0][1]) % p
            a21 = (a_n * A[1][0]) % p
            a22 = (a_n * A[1][1] + b_n) % p
            if (a11 == B[0][0] % p and a12 == B[0][1] % p and
                a21 == B[1][0] % p and a22 == B[1][1] % p):
                if min_n == -1 or candidate < min_n:
                    min_n = candidate
        if min_n != -1:
            print(min_n)
            return
        else:
            print(-1)
            return

solve_case()
0