結果
| 問題 |
No.950 行列累乗
|
| コンテスト | |
| ユーザー |
lam6er
|
| 提出日時 | 2025-03-20 18:57:46 |
| 言語 | PyPy3 (7.3.15) |
| 結果 |
WA
|
| 実行時間 | - |
| コード長 | 6,349 bytes |
| コンパイル時間 | 215 ms |
| コンパイル使用メモリ | 82,704 KB |
| 実行使用メモリ | 77,960 KB |
| 最終ジャッジ日時 | 2025-03-20 18:58:49 |
| 合計ジャッジ時間 | 4,961 ms |
|
ジャッジサーバーID (参考情報) |
judge5 / judge4 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 4 |
| other | AC * 40 WA * 17 |
ソースコード
import sys
import math
from math import sqrt, gcd
def input():
return sys.stdin.read()
def mat_mult(a, b, p):
return [
[(a[0][0]*b[0][0] + a[0][1]*b[1][0]) % p,
(a[0][0]*b[0][1] + a[0][1]*b[1][1]) % p],
[(a[1][0]*b[0][0] + a[1][1]*b[1][0]) % p,
(a[1][0]*b[0][1] + a[1][1]*b[1][1]) % p]
]
def mat_pow(mat, power, p):
result = [[1,0], [0,1]] # Identity matrix
while power > 0:
if power % 2 == 1:
result = mat_mult(result, mat, p)
mat = mat_mult(mat, mat, p)
power //= 2
return result
def baby_step_giant_step(a, b, p):
if b == 1:
return 0
if a == 0:
return -1
m = int(math.isqrt(p)) + 1
table = {}
current = 1
for j in range(m):
if current not in table:
table[current] = j
current = (current * a) % p
am_inv = pow(a, m * (p-2), p)
gamma = b
for i in range(m):
if gamma in table:
return i * m + table[gamma]
gamma = (gamma * am_inv) % p
return -1
def discrete_log(g, h, p):
if g == 0:
if h == 0:
return 1
else:
return -1
if h == 1:
return 0
if g == h:
return 1
return baby_step_giant_step(g, h, p)
def order(g, p):
if g == 0:
return -1
if g % p == 0:
return -1
phi = p - 1
factors = {}
temp = phi
for i in [2, 3, 5, 7]:
while temp % i == 0:
factors[i] = factors.get(i, 0) + 1
temp //= i
i = 11
while i*i <= temp:
while temp % i == 0:
factors[i] = factors.get(i, 0) + 1
temp //= i
i += 2
if temp > 1:
factors[temp] = 1
order = phi
for prime in factors:
test_order = order // prime
if pow(g, test_order, p) == 1:
order = test_order
while order % prime == 0:
candidate = order // prime
if pow(g, candidate, p) != 1:
break
order = candidate
return order
def solve_case():
data = input().split()
ptr = 0
p = int(data[ptr]); ptr +=1
A = [[int(data[ptr])%p, int(data[ptr+1])%p], [int(data[ptr+2])%p, int(data[ptr+3])%p]]
ptr +=4
B = [[int(data[ptr])%p, int(data[ptr+1])%p], [int(data[ptr+2])%p, int(data[ptr+3])%p]]
ptr +=4
det_a = (A[0][0] * A[1][1] - A[0][1] * A[1][0]) % p
det_b = (B[0][0] * B[1][1] - B[0][1] * B[1][0]) % p
if det_a == 0:
if det_b != 0:
print(-1)
return
zero_A = all(A[i][j] == 0 for i in range(2) for j in range(2))
if zero_A:
zero_B = all(B[i][j] == 0 for i in range(2) for j in range(2))
if zero_B:
print(1)
return
else:
print(-1)
return
non_zero_rows = []
for i in range(2):
for j in range(2):
if A[i][j] != 0:
non_zero_rows.append((i, j))
if not non_zero_rows:
print(-1)
return
k_values = []
for (i,j) in non_zero_rows:
a_ij = A[i][j]
b_ij = B[i][j]
inv_a = pow(a_ij, p-2, p)
k = (b_ij * inv_a) % p
k_values.append(k)
k = k_values[0]
if not all(kv == k for kv in k_values):
print(-1)
return
for i in range(2):
for j in range(2):
a = A[i][j]
b = B[i][j]
if (a * k) % p != b % p:
print(-1)
return
t = (A[0][0] + A[1][1]) % p
if t == 0:
if k == 0:
print(2)
return
elif k == 1:
print(1)
return
else:
print(-1)
return
x = discrete_log(t, k, p)
if x == -1:
print(-1)
return
else:
n = x + 1
if n <=0:
print(-1)
return
current_k = pow(t, x, p)
if current_k != k:
print(-1)
return
print(n)
return
else:
if det_b ==0:
print(-1)
return
log_val = discrete_log(det_a, det_b, p)
if log_val == -1:
print(-1)
return
ord_deta = order(det_a, p)
if ord_deta == -1:
print(-1)
return
t = (A[0][0] + A[1][1]) % p
d = det_a
def compute_an_bn(n, p, t, d):
if n == 1:
a_n = 1 % p
b_n = 0 % p
return (a_n, b_n)
if n == 2:
a_n = t % p
b_n = (-d) % p
return (a_n, b_n)
trans_mat = [[t % p, (-d) % p], [1, 0]]
power = n - 2
mat_power = mat_pow(trans_mat, power, p)
a_2 = t % p
a_1 = 1 % p
a_n = (mat_power[0][0] * a_2 + mat_power[0][1] * a_1) % p
a_n_minus_1 = (mat_power[1][0] * a_2 + mat_power[1][1] * a_1) % p
b_n = (-d * a_n_minus_1) % p
return (a_n, b_n)
candidates = []
if log_val <=0:
log_val += ord_deta
while log_val <=0:
log_val += ord_deta
current_n = log_val
max_candidates = 100
for i in range(max_candidates):
candidates.append(current_n)
current_n += ord_deta
if current_n > 1e18:
break
min_n = -1
for candidate in candidates:
a_n, b_n = compute_an_bn(candidate, p, t, d)
a11 = (a_n * A[0][0] + b_n) % p
a12 = (a_n * A[0][1]) % p
a21 = (a_n * A[1][0]) % p
a22 = (a_n * A[1][1] + b_n) % p
if (a11 == B[0][0] % p and a12 == B[0][1] % p and
a21 == B[1][0] % p and a22 == B[1][1] % p):
if min_n == -1 or candidate < min_n:
min_n = candidate
if min_n != -1:
print(min_n)
return
else:
print(-1)
return
solve_case()
lam6er