結果
問題 | No.950 行列累乗 |
ユーザー | maspy |
提出日時 | 2019-12-14 00:17:32 |
言語 | Python3 (3.12.2 + numpy 1.26.4 + scipy 1.12.0) |
結果 |
AC
|
実行時間 | 732 ms / 2,000 ms |
コード長 | 4,389 bytes |
コンパイル時間 | 165 ms |
コンパイル使用メモリ | 13,312 KB |
実行使用メモリ | 71,964 KB |
最終ジャッジ日時 | 2024-06-27 23:13:39 |
合計ジャッジ時間 | 40,689 ms |
ジャッジサーバーID (参考情報) |
judge1 / judge5 |
(要ログイン)
テストケース
テストケース表示入力 | 結果 | 実行時間 実行使用メモリ |
---|---|---|
testcase_00 | AC | 537 ms
44,916 KB |
testcase_01 | AC | 535 ms
44,924 KB |
testcase_02 | AC | 532 ms
44,788 KB |
testcase_03 | AC | 531 ms
44,416 KB |
testcase_04 | AC | 587 ms
44,796 KB |
testcase_05 | AC | 648 ms
56,840 KB |
testcase_06 | AC | 533 ms
45,048 KB |
testcase_07 | AC | 535 ms
44,916 KB |
testcase_08 | AC | 533 ms
44,656 KB |
testcase_09 | AC | 541 ms
44,656 KB |
testcase_10 | AC | 543 ms
44,660 KB |
testcase_11 | AC | 544 ms
44,788 KB |
testcase_12 | AC | 540 ms
44,920 KB |
testcase_13 | AC | 536 ms
44,920 KB |
testcase_14 | AC | 536 ms
45,048 KB |
testcase_15 | AC | 536 ms
44,664 KB |
testcase_16 | AC | 532 ms
44,536 KB |
testcase_17 | AC | 528 ms
44,788 KB |
testcase_18 | AC | 527 ms
44,396 KB |
testcase_19 | AC | 536 ms
44,916 KB |
testcase_20 | AC | 529 ms
44,776 KB |
testcase_21 | AC | 562 ms
47,348 KB |
testcase_22 | AC | 698 ms
67,176 KB |
testcase_23 | AC | 641 ms
57,844 KB |
testcase_24 | AC | 662 ms
62,988 KB |
testcase_25 | AC | 631 ms
58,604 KB |
testcase_26 | AC | 661 ms
64,068 KB |
testcase_27 | AC | 555 ms
45,120 KB |
testcase_28 | AC | 696 ms
66,196 KB |
testcase_29 | AC | 674 ms
63,228 KB |
testcase_30 | AC | 686 ms
63,616 KB |
testcase_31 | AC | 667 ms
62,568 KB |
testcase_32 | AC | 730 ms
71,148 KB |
testcase_33 | AC | 688 ms
65,428 KB |
testcase_34 | AC | 722 ms
70,988 KB |
testcase_35 | AC | 704 ms
65,368 KB |
testcase_36 | AC | 536 ms
44,664 KB |
testcase_37 | AC | 542 ms
46,580 KB |
testcase_38 | AC | 538 ms
46,848 KB |
testcase_39 | AC | 533 ms
45,048 KB |
testcase_40 | AC | 675 ms
67,384 KB |
testcase_41 | AC | 663 ms
60,868 KB |
testcase_42 | AC | 556 ms
47,520 KB |
testcase_43 | AC | 728 ms
71,848 KB |
testcase_44 | AC | 725 ms
71,480 KB |
testcase_45 | AC | 730 ms
71,704 KB |
testcase_46 | AC | 721 ms
71,964 KB |
testcase_47 | AC | 646 ms
63,436 KB |
testcase_48 | AC | 728 ms
71,960 KB |
testcase_49 | AC | 725 ms
71,960 KB |
testcase_50 | AC | 728 ms
71,584 KB |
testcase_51 | AC | 732 ms
71,556 KB |
testcase_52 | AC | 543 ms
47,736 KB |
testcase_53 | AC | 546 ms
47,356 KB |
testcase_54 | AC | 528 ms
44,660 KB |
testcase_55 | AC | 527 ms
45,044 KB |
testcase_56 | AC | 548 ms
47,484 KB |
testcase_57 | AC | 668 ms
60,800 KB |
testcase_58 | AC | 534 ms
44,404 KB |
testcase_59 | AC | 530 ms
44,924 KB |
testcase_60 | AC | 535 ms
44,664 KB |
ソースコード
"""Discrete logarithm for 2x2 matrices modulo a prime.

Reads a modulus MOD and two 2x2 integer matrices A and B from standard input
and prints the smallest positive n with A**n == B (mod MOD), or -1 when the
search finds none.

MOD is assumed to be prime: every modular inverse below is computed as
``pow(x, MOD - 2, MOD)``.  Matrix hashing packs entries into 32-bit fields,
so MOD is also assumed to be below 2**32.
"""
import sys
import itertools

import numpy as np


def _batch_mat_mul_mod(stack, Y, MOD):
    """Right-multiply every matrix of ``stack`` (shape (K, N, N)) by ``Y`` mod MOD.

    Each scalar product is reduced modulo MOD *before* the products are
    summed, so intermediate values stay inside int64 as long as
    (MOD - 1)**2 fits.
    """
    # products[r, i, k, j] = stack[r, i, k] * Y[k, j]
    products = stack[:, :, :, None] * Y[None, None, :, :] % MOD
    return products.sum(axis=2) % MOD


def _mat_mul_mod(X, Y, MOD):
    """Return the matrix product ``X @ Y`` reduced modulo MOD."""
    return _batch_mat_mul_mod(X[None, :, :], Y, MOD)[0]


def make_power_mat(A, L, MOD):
    """Return an array ``P`` of shape (L, N, N) with ``P[m] == A**m (mod MOD)``.

    The table is built by doubling: once the first 2**n powers are known,
    multiplying all of them by A**(2**n) yields the next 2**n powers.
    """
    N = A.shape[0]
    assert A.shape == (N, N)
    bits = L.bit_length()
    powers = np.empty((1 << bits, N, N), np.int64)
    powers[0] = np.eye(N, dtype=np.int64)
    step = A.copy()  # invariant: step == A**(2**n) at the top of the loop
    for n in range(bits):
        half = 1 << n
        powers[half:2 * half] = _batch_mat_mul_mod(powers[:half], step, MOD)
        step = _mat_mul_mod(step, step, MOD)
    return powers[:L]


def make_power(a, L, MOD):
    """Return an int64 array ``p`` of length L with ``p[m] == a**m (mod MOD)``.

    Scalar counterpart of :func:`make_power_mat`, using the same doubling.
    """
    a = int(a) % MOD
    bits = L.bit_length()
    powers = np.empty(1 << bits, np.int64)
    powers[0] = 1
    for n in range(bits):
        half = 1 << n
        powers[half:2 * half] = powers[:half] * a % MOD
        a = a * a % MOD
    return powers[:L]


def BSGS(a, b, MOD):
    """Baby-step giant-step: smallest m >= 0 with ``a**m == b (mod MOD)``.

    Returns -1 when no exponent below M*M (M = int(sqrt(MOD) + 100)) works.
    ``a`` must be non-zero modulo MOD.
    """
    assert a != 0
    M = int(MOD ** .5 + 100)
    giant_base = pow(int(a), M, MOD)      # a**M
    a_inv = pow(int(a), MOD - 2, MOD)     # a**-1 (Fermat)
    baby = make_power(a_inv, M, MOD)      # a**-r for r in [0, M)
    giant = make_power(giant_base, M, MOD)  # a**(M*q) for q in [0, M)
    # a**(M*q) == b * a**-r  <=>  a**(M*q + r) == b
    targets = b * baby % MOD
    hit = np.isin(giant, targets)
    if not hit.any():
        return -1
    # Smallest q first, then smallest r for that q: this is the minimal
    # exponent because M*q + r with 0 <= r < M is a unique decomposition.
    q = np.flatnonzero(hit)[0]
    r = np.flatnonzero(targets == giant[q])[0]
    return int(q * M + r)


def solve_naive(A, B, MOD):
    """Brute force: return the first n in 1..9999 with ``A**n == B``, else -1."""
    X = np.eye(2, dtype=np.int64)
    for n in range(1, 10 ** 4):
        X = _mat_mul_mod(X, A, MOD)
        if (X == B).all():
            return n
    return -1


def matrix_power(A, n, MOD):
    """Return ``A**n (mod MOD)`` for a 2x2 matrix by recursive squaring."""
    if n == 0:
        return np.eye(2, dtype=np.int64)
    half = matrix_power(A, n // 2, MOD)
    squared = _mat_mul_mod(half, half, MOD)
    return _mat_mul_mod(A, squared, MOD) if n & 1 else squared


def matrix_inverse(A, MOD):
    """Return the inverse of the 2x2 matrix ``A`` modulo MOD.

    Uses adjugate / determinant (a matrix power would also work).  Fails the
    internal assertion when ``A`` is singular modulo MOD.
    """
    det = (A[0, 0] * A[1, 1] - A[1, 0] * A[0, 1]) % MOD
    adjugate = np.array(
        [
            [A[1, 1], -A[0, 1]],
            [-A[1, 0], A[0, 0]],
        ],
        dtype=np.int64,
    )
    inverse = adjugate * pow(int(det), MOD - 2, MOD) % MOD
    assert np.all(_mat_mul_mod(A, inverse, MOD) == np.eye(2, dtype=np.int64))
    return inverse


def solve(A, B, MOD):
    """Return the smallest n >= 1 with ``A**n == B (mod MOD)``, or -1.

    ``A`` and ``B`` are 2x2 integer matrices (arrays or nested lists); they
    are reduced modulo MOD before use.
    """
    A = np.asarray(A, dtype=np.int64) % MOD
    B = np.asarray(B, dtype=np.int64) % MOD
    if MOD == 2:
        return solve_naive(A, B, MOD)

    identity = np.eye(2, dtype=np.int64)
    # Determinants and trace in Python ints: no overflow concerns.
    a00, a01, a10, a11 = (int(v) for v in A.ravel())
    b00, b01, b10, b11 = (int(v) for v in B.ravel())
    detA = (a00 * a11 - a10 * a01) % MOD
    detB = (b00 * b11 - b10 * b01) % MOD
    trA = (a00 + a11) % MOD

    if detA == 0 and trA == 0:
        # Nilpotent: A**2 == 0, so only n = 1 and n = 2 are distinct.
        if (A == B).all():
            return 1
        if (B == 0).all():
            return 2
        return -1

    if detA == 0:
        # Cayley-Hamilton with det 0: A**n == t**(n-1) * A with t = trace.
        rows, cols = np.nonzero(A)
        i, j = rows[0], cols[0]
        ratio = int(B[i, j]) * pow(int(A[i, j]), MOD - 2, MOD) % MOD
        if (ratio * A % MOD != B).any():
            return -1  # B is not a scalar multiple of A
        if ratio == 1:
            return 1
        n = BSGS(trA, ratio, MOD)
        return -1 if n == -1 else n + 1

    # Invertible A: reduce to the determinant-1 case.
    # n fixes the exponent modulo e, the multiplicative order of det A.
    n = BSGS(detA, detB, MOD)
    if n == -1:
        return -1
    e = BSGS(detA, pow(detA, MOD - 2, MOD), MOD) + 1
    A_inv = matrix_inverse(A, MOD)
    B = _mat_mul_mod(B, matrix_power(A_inv, n, MOD), MOD)  # B * A**-n
    A = matrix_power(A, e, MOD)                            # det(A**e) == 1
    if n == 0 and np.array_equal(B, identity):
        # The exponent must be positive, so the answer is e times the order
        # of A**e; that order is 1 + (smallest m with (A**e)**m == (A**e)**-1).
        k = solve_SL2(A, matrix_inverse(A, MOD), MOD) + 1
        return k * e
    k = solve_SL2(A, B, MOD)
    if k == -1:
        return -1
    return e * k + n


def to_hash(A):
    """Pack each 2x2 matrix of a (K, 2, 2) array into one Python int.

    Entries are placed in separate 32-bit fields, so distinct matrices get
    distinct keys provided every entry is below 2**32.  Returns an
    object-dtype array of length K.
    """
    boxed = A.astype(object)  # Python ints: the packed value exceeds 64 bits
    return (
        (boxed[:, 0, 0] << 96)
        + (boxed[:, 1, 0] << 64)
        + (boxed[:, 0, 1] << 32)
        + boxed[:, 1, 1]
    )


def solve_SL2(A, B, MOD):
    """Smallest m >= 0 with ``A**m == B (mod MOD)`` for det(A) == 1, or -1.

    Looking at the eigenvalues of A shows its order is small:
      * two eigenvalues in F_p      -> period divides p - 1
      * two eigenvalues in F_{p^2}  -> period divides p + 1 (norm is 1)
      * a repeated eigenvalue       -> its square is 1, period divides 2p
    so a baby-step giant-step search over about 2p exponents is enough.
    """
    search_range = 2 * MOD
    M = int(search_range ** .5 + 100)
    giant_base = matrix_power(A, M, MOD)
    A_inv = matrix_inverse(A, MOD)
    baby = make_power_mat(A_inv, M, MOD)        # A**-r
    giant = make_power_mat(giant_base, M, MOD)  # A**(M*q)
    # A**(M*q) == A**-r * B  <=>  A**(M*q + r) == B
    shifted = _batch_mat_mul_mod(baby, B, MOD)

    first_r = {}
    for r, key in enumerate(to_hash(shifted).tolist()):
        first_r.setdefault(key, r)  # keep the smallest r per matrix
    for q, key in enumerate(to_hash(giant).tolist()):
        r = first_r.get(key)
        if r is not None:
            return q * M + r
    return -1


def main():
    """Read MOD (first line) and the eight matrix entries, print the answer."""
    stdin = sys.stdin.buffer
    MOD = int(stdin.readline())
    data = list(map(int, stdin.read().split()))
    A = np.array(data[:4], dtype=np.int64).reshape(2, 2)
    B = np.array(data[4:], dtype=np.int64).reshape(2, 2)
    print(solve(A, B, MOD))


if __name__ == "__main__":
    main()