結果

問題 No.950 行列累乗
ユーザー maspymaspy
提出日時 2019-12-14 00:17:32
言語 Python3
(3.12.2 + numpy 1.26.4 + scipy 1.12.0)
結果
AC  
実行時間 732 ms / 2,000 ms
コード長 4,389 bytes
コンパイル時間 165 ms
コンパイル使用メモリ 13,312 KB
実行使用メモリ 71,964 KB
最終ジャッジ日時 2024-06-27 23:13:39
合計ジャッジ時間 40,689 ms
ジャッジサーバーID
(参考情報)
judge1 / judge5
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 537 ms
44,916 KB
testcase_01 AC 535 ms
44,924 KB
testcase_02 AC 532 ms
44,788 KB
testcase_03 AC 531 ms
44,416 KB
testcase_04 AC 587 ms
44,796 KB
testcase_05 AC 648 ms
56,840 KB
testcase_06 AC 533 ms
45,048 KB
testcase_07 AC 535 ms
44,916 KB
testcase_08 AC 533 ms
44,656 KB
testcase_09 AC 541 ms
44,656 KB
testcase_10 AC 543 ms
44,660 KB
testcase_11 AC 544 ms
44,788 KB
testcase_12 AC 540 ms
44,920 KB
testcase_13 AC 536 ms
44,920 KB
testcase_14 AC 536 ms
45,048 KB
testcase_15 AC 536 ms
44,664 KB
testcase_16 AC 532 ms
44,536 KB
testcase_17 AC 528 ms
44,788 KB
testcase_18 AC 527 ms
44,396 KB
testcase_19 AC 536 ms
44,916 KB
testcase_20 AC 529 ms
44,776 KB
testcase_21 AC 562 ms
47,348 KB
testcase_22 AC 698 ms
67,176 KB
testcase_23 AC 641 ms
57,844 KB
testcase_24 AC 662 ms
62,988 KB
testcase_25 AC 631 ms
58,604 KB
testcase_26 AC 661 ms
64,068 KB
testcase_27 AC 555 ms
45,120 KB
testcase_28 AC 696 ms
66,196 KB
testcase_29 AC 674 ms
63,228 KB
testcase_30 AC 686 ms
63,616 KB
testcase_31 AC 667 ms
62,568 KB
testcase_32 AC 730 ms
71,148 KB
testcase_33 AC 688 ms
65,428 KB
testcase_34 AC 722 ms
70,988 KB
testcase_35 AC 704 ms
65,368 KB
testcase_36 AC 536 ms
44,664 KB
testcase_37 AC 542 ms
46,580 KB
testcase_38 AC 538 ms
46,848 KB
testcase_39 AC 533 ms
45,048 KB
testcase_40 AC 675 ms
67,384 KB
testcase_41 AC 663 ms
60,868 KB
testcase_42 AC 556 ms
47,520 KB
testcase_43 AC 728 ms
71,848 KB
testcase_44 AC 725 ms
71,480 KB
testcase_45 AC 730 ms
71,704 KB
testcase_46 AC 721 ms
71,964 KB
testcase_47 AC 646 ms
63,436 KB
testcase_48 AC 728 ms
71,960 KB
testcase_49 AC 725 ms
71,960 KB
testcase_50 AC 728 ms
71,584 KB
testcase_51 AC 732 ms
71,556 KB
testcase_52 AC 543 ms
47,736 KB
testcase_53 AC 546 ms
47,356 KB
testcase_54 AC 528 ms
44,660 KB
testcase_55 AC 527 ms
45,044 KB
testcase_56 AC 548 ms
47,484 KB
testcase_57 AC 668 ms
60,800 KB
testcase_58 AC 534 ms
44,404 KB
testcase_59 AC 530 ms
44,924 KB
testcase_60 AC 535 ms
44,664 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
read = sys.stdin.buffer.read
readline = sys.stdin.buffer.readline
readlines = sys.stdin.buffer.readlines

import numpy as np
import itertools

MOD = int(readline())
data = list(map(int,read().split()))

A = np.int64(data[:4]).reshape(2,2)
B = np.int64(data[4:]).reshape(2,2)

def make_power_mat(A, L, MOD=MOD):
    N = A.shape[0]
    assert A.shape == (N,N)
    B = L.bit_length()
    x = np.empty((1<<B,N,N), np.int64)
    x[0] = np.eye(N,dtype=np.int64)
    X = A.copy()
    for n in range(B):
        for i,j in itertools.product(range(N),repeat=2):
            x[1<<n:1<<(n+1),i,j] = (x[:1<<n,i,:] * X[:,j] % MOD).sum(axis=1) % MOD
        Y = X.copy()
        for i,j in itertools.product(range(N),repeat=2):
            X[i,j] = (Y[i,:] * Y[:,j] % MOD).sum() % MOD
    return x[:L]

def make_power(a, L, MOD=MOD):
    B = L.bit_length()
    x = np.empty((1<<B), np.int64)
    x[0] = 1
    for n in range(B):
        x[1<<n:1<<(n+1)] = x[:1<<n] * a % MOD
        a *= a; a %= MOD
    return x[:L]

def BSGS(a,b,MOD):
    assert a != 0
    M = int(MOD ** .5 + 100)
    c = pow(int(a),M,MOD)
    d = pow(int(a),MOD-2,MOD)
    B = make_power(d,M,MOD)
    G = make_power(c,M,MOD)
    I = np.in1d(G, b * B % MOD)
    if not I.any():
        return -1
    q = np.where(I)[0][0]
    r = np.where(b * B % MOD == G[q])[0][0]
    return q * M + r

def solve_naive(A,B,MOD):
    X = np.int64([[1,0],[0,1]])
    for n in range(1,10**4):
        X = np.dot(X,A) % MOD
        if (X == B).all():
            return n
    return -1

def matrix_power(A,n,MOD):
    if n == 0:
        return np.eye(2,dtype=np.int64)
    B = matrix_power(A,n//2,MOD)
    B = np.dot(B,B) % MOD
    return np.dot(A,B) % MOD if n & 1 else B

def matrix_inverse(A,MOD):
    # powerでもよいが
    det = (A[0,0]*A[1,1] - A[1,0]*A[0,1]) % MOD
    B = np.array([
        [A[1,1],-A[0,1]],
        [-A[1,0],A[0,0]],
    ])
    x = pow(int(det),MOD-2,MOD)
    B *= x; B %= MOD
    assert np.all(np.dot(A,B) % MOD == np.eye(2,dtype=np.int64))
    return B

def solve(A,B,MOD):
    if MOD == 2:
        return solve_naive(A,B,MOD)
    detA = (A[0,0]*A[1,1] - A[1,0]*A[0,1]) % MOD
    detB = (B[0,0]*B[1,1] - B[1,0]*B[0,1]) % MOD
    trA = (A[0,0] + A[1,1]) % MOD
    if detA == 0 and trA == 0:
        # べき零
        if (A == B).all():
            return 1
        if (B == 0).all():
            return 2
        return -1
    if detA == 0:
        # A^n = t^{n-1}A
        I,J = np.where(A != 0); i = I[0]; j = J[0]
        a = A[i,j]; b = B[i,j]
        k = b * pow(int(a),MOD-2,MOD) % MOD
        if (k * A % MOD != B).any():
            return -1
        if k == 1:
            return 1
        n = BSGS(trA,k,MOD)
        return -1 if n == -1 else n + 1
    # det A == 1 に帰着したい
    n = BSGS(detA,detB,MOD)
    e = BSGS(detA,pow(int(detA),MOD-2,MOD),MOD) + 1
    if n == -1:
        return -1
    Ainv = matrix_inverse(A,MOD)
    B = np.dot(B,matrix_power(Ainv,n,MOD)) % MOD
    A = matrix_power(A,e,MOD)
    Ainv = matrix_inverse(A,MOD)
    if n == 0 and np.all(B == np.eye(2,dtype=np.int64)):
        k = solve_SL2(A,Ainv,MOD) + 1
        return k * e
    k = solve_SL2(A,B,MOD)
    if k == -1:
        return -1
    return e * k + n

def to_hash(A):
    A = A.astype(object)
    return (A[:,0,0] << 96) + (A[:,1,0] << 64) + (A[:,0,1] << 32) + (A[:,1,1])

def solve_SL2(A,B,MOD):
    """
    A の固有値を考えると、Aの位数は十分小さいことが分かる
    F_pで2固有値 → p-1周期
    F_{p^2}で2固有値 → ノルム1にしたのでp+1周期
    固有値が重複 → 2乗すると固有値が1 → 2p周期
    2p程度で、BGSGをすればよい
    """
    U = 2 * MOD
    M = int(U ** .5 + 100)
    c = matrix_power(A,M,MOD)
    d = matrix_inverse(A,MOD)
    Baby = make_power_mat(d,M,MOD)
    Giant = make_power_mat(c,M,MOD)
    BB = np.zeros((M,2,2),np.int64)
    for i,j in itertools.product(range(2),repeat=2):
        BB[:,i,j] = (Baby[:,i,:] * B[:,j][None,:] % MOD).sum(axis=1) % MOD
    Gh = to_hash(Giant).tolist()
    BBh = to_hash(BB).tolist()
    se = set(BBh)
    q = -1
    for i,g in enumerate(Gh):
        if g in se:
            q = i
            break
    if q == -1:
        return -1
    g = Gh[q]
    r = -1
    for i,b in enumerate(BBh):
        if b == g:
            r = i
            break
    return q * M + r

print(solve(A,B,MOD))
0