結果
問題 |
No.950 行列累乗
|
ユーザー |
![]() |
提出日時 | 2025-06-12 18:28:52 |
言語 | PyPy3 (7.3.15) |
結果 |
WA
|
実行時間 | - |
コード長 | 3,020 bytes |
コンパイル時間 | 255 ms |
コンパイル使用メモリ | 82,356 KB |
実行使用メモリ | 55,556 KB |
最終ジャッジ日時 | 2025-06-12 18:29:36 |
合計ジャッジ時間 | 4,316 ms |
ジャッジサーバーID (参考情報) |
judge3 / judge4 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | AC * 4 |
other | AC * 23 WA * 34 |
ソースコード
import sys import math def input(): return sys.stdin.read() def mat_mult(a, b, p): return [ [(a[0][0]*b[0][0] + a[0][1]*b[1][0]) % p, (a[0][0]*b[0][1] + a[0][1]*b[1][1]) % p], [(a[1][0]*b[0][0] + a[1][1]*b[1][0]) % p, (a[1][0]*b[0][1] + a[1][1]*b[1][1]) % p] ] def mat_pow(mat, power, p): result = [[1,0], [0,1]] while power > 0: if power % 2 == 1: result = mat_mult(result, mat, p) mat = mat_mult(mat, mat, p) power //= 2 return result def mat_eq(a, b): return a[0] == b[0] and a[1] == b[1] def det(mat, p): return (mat[0][0] * mat[1][1] - mat[0][1] * mat[1][0]) % p def solve(): data = input().split() idx = 0 p = int(data[idx]); idx +=1 A = [ [int(data[idx])%p, int(data[idx+1])%p], [int(data[idx+2])%p, int(data[idx+3])%p] ] idx +=4 B = [ [int(data[idx])%p, int(data[idx+1])%p], [int(data[idx+2])%p, int(data[idx+3])%p] ] idx +=4 detA = det(A, p) detB = det(B, p) if detA % p == 0: if detB % p != 0: print(-1) return current = A n = 1 if mat_eq(current, B): print(1) return for n in range(2, 100): current = mat_mult(current, A, p) if mat_eq(current, B): print(n) return if all(cell == 0 for row in current for cell in row): break print(-1) return t = (A[0][0] + A[1][1]) % p d = detA % p A12 = A[0][1] A21 = A[1][0] B12 = B[0][1] B21 = B[1][0] if A12 == 0 and B12 != 0: print(-1) return if A21 == 0 and B21 != 0: print(-1) return c_n = None if A12 != 0: inv_A12 = pow(A12, p-2, p) c_n = (B12 * inv_A12) % p if A21 != 0: inv_A21 = pow(A21, p-2, p) c_n2 = (B21 * inv_A21) % p if c_n != c_n2: print(-1) return else: if A21 != 0: inv_A21 = pow(A21, p-2, p) c_n = (B21 * inv_A21) % p else: c_n = 0 s_n1 = (B[0][0] - c_n * A[0][0]) % p s_n2 = (B[1][1] - c_n * A[1][1]) % p if s_n1 != s_n2: print(-1) return s_n = s_n1 a1 = 1 s1 = 0 a2 = t % p s2 = (-d) % p if a1 == c_n and s1 == s_n: print(1) return if a2 == c_n and s2 == s_n: print(2) return current_a = a2 current_s = s2 prev_a = a1 prev_s = s1 n = 3 found = False for _ in range(1000): new_a = (t * current_a - d * prev_a) % p new_s = (t * current_s - d * prev_s) % p if new_a == c_n and new_s == s_n: found = True break prev_a, current_a = current_a, new_a prev_s, current_s = current_s, new_s n += 1 if found: print(n) return print(-1) solve()