結果
| 問題 |
No.950 行列累乗
|
| コンテスト | |
| ユーザー |
gew1fw
|
| 提出日時 | 2025-06-12 13:14:16 |
| 言語 | PyPy3 (7.3.15) |
| 結果 |
WA
|
| 実行時間 | - |
| コード長 | 3,020 bytes |
| コンパイル時間 | 164 ms |
| コンパイル使用メモリ | 82,316 KB |
| 実行使用メモリ | 55,368 KB |
| 最終ジャッジ日時 | 2025-06-12 13:16:45 |
| 合計ジャッジ時間 | 4,618 ms |
|
ジャッジサーバーID (参考情報) |
judge3 / judge4 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 4 |
| other | AC * 23 WA * 34 |
ソースコード
import sys
import math
def input():
return sys.stdin.read()
def mat_mult(a, b, p):
return [
[(a[0][0]*b[0][0] + a[0][1]*b[1][0]) % p,
(a[0][0]*b[0][1] + a[0][1]*b[1][1]) % p],
[(a[1][0]*b[0][0] + a[1][1]*b[1][0]) % p,
(a[1][0]*b[0][1] + a[1][1]*b[1][1]) % p]
]
def mat_pow(mat, power, p):
result = [[1,0], [0,1]]
while power > 0:
if power % 2 == 1:
result = mat_mult(result, mat, p)
mat = mat_mult(mat, mat, p)
power //= 2
return result
def mat_eq(a, b):
return a[0] == b[0] and a[1] == b[1]
def det(mat, p):
return (mat[0][0] * mat[1][1] - mat[0][1] * mat[1][0]) % p
def solve():
data = input().split()
idx = 0
p = int(data[idx]); idx +=1
A = [
[int(data[idx])%p, int(data[idx+1])%p],
[int(data[idx+2])%p, int(data[idx+3])%p]
]
idx +=4
B = [
[int(data[idx])%p, int(data[idx+1])%p],
[int(data[idx+2])%p, int(data[idx+3])%p]
]
idx +=4
detA = det(A, p)
detB = det(B, p)
if detA % p == 0:
if detB % p != 0:
print(-1)
return
current = A
n = 1
if mat_eq(current, B):
print(1)
return
for n in range(2, 100):
current = mat_mult(current, A, p)
if mat_eq(current, B):
print(n)
return
if all(cell == 0 for row in current for cell in row):
break
print(-1)
return
t = (A[0][0] + A[1][1]) % p
d = detA % p
A12 = A[0][1]
A21 = A[1][0]
B12 = B[0][1]
B21 = B[1][0]
if A12 == 0 and B12 != 0:
print(-1)
return
if A21 == 0 and B21 != 0:
print(-1)
return
c_n = None
if A12 != 0:
inv_A12 = pow(A12, p-2, p)
c_n = (B12 * inv_A12) % p
if A21 != 0:
inv_A21 = pow(A21, p-2, p)
c_n2 = (B21 * inv_A21) % p
if c_n != c_n2:
print(-1)
return
else:
if A21 != 0:
inv_A21 = pow(A21, p-2, p)
c_n = (B21 * inv_A21) % p
else:
c_n = 0
s_n1 = (B[0][0] - c_n * A[0][0]) % p
s_n2 = (B[1][1] - c_n * A[1][1]) % p
if s_n1 != s_n2:
print(-1)
return
s_n = s_n1
a1 = 1
s1 = 0
a2 = t % p
s2 = (-d) % p
if a1 == c_n and s1 == s_n:
print(1)
return
if a2 == c_n and s2 == s_n:
print(2)
return
current_a = a2
current_s = s2
prev_a = a1
prev_s = s1
n = 3
found = False
for _ in range(1000):
new_a = (t * current_a - d * prev_a) % p
new_s = (t * current_s - d * prev_s) % p
if new_a == c_n and new_s == s_n:
found = True
break
prev_a, current_a = current_a, new_a
prev_s, current_s = current_s, new_s
n += 1
if found:
print(n)
return
print(-1)
solve()
gew1fw