結果
問題 |
No.950 行列累乗
|
ユーザー |
![]() |
提出日時 | 2025-03-20 20:23:46 |
言語 | PyPy3 (7.3.15) |
結果 |
WA
|
実行時間 | - |
コード長 | 6,349 bytes |
コンパイル時間 | 320 ms |
コンパイル使用メモリ | 82,284 KB |
実行使用メモリ | 78,852 KB |
最終ジャッジ日時 | 2025-03-20 20:25:59 |
合計ジャッジ時間 | 5,537 ms |
ジャッジサーバーID (参考情報) |
judge1 / judge2 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | AC * 4 |
other | AC * 40 WA * 17 |
ソースコード
import sys import math from math import sqrt, gcd def input(): return sys.stdin.read() def mat_mult(a, b, p): return [ [(a[0][0]*b[0][0] + a[0][1]*b[1][0]) % p, (a[0][0]*b[0][1] + a[0][1]*b[1][1]) % p], [(a[1][0]*b[0][0] + a[1][1]*b[1][0]) % p, (a[1][0]*b[0][1] + a[1][1]*b[1][1]) % p] ] def mat_pow(mat, power, p): result = [[1,0], [0,1]] # Identity matrix while power > 0: if power % 2 == 1: result = mat_mult(result, mat, p) mat = mat_mult(mat, mat, p) power //= 2 return result def baby_step_giant_step(a, b, p): if b == 1: return 0 if a == 0: return -1 m = int(math.isqrt(p)) + 1 table = {} current = 1 for j in range(m): if current not in table: table[current] = j current = (current * a) % p am_inv = pow(a, m * (p-2), p) gamma = b for i in range(m): if gamma in table: return i * m + table[gamma] gamma = (gamma * am_inv) % p return -1 def discrete_log(g, h, p): if g == 0: if h == 0: return 1 else: return -1 if h == 1: return 0 if g == h: return 1 return baby_step_giant_step(g, h, p) def order(g, p): if g == 0: return -1 if g % p == 0: return -1 phi = p - 1 factors = {} temp = phi for i in [2, 3, 5, 7]: while temp % i == 0: factors[i] = factors.get(i, 0) + 1 temp //= i i = 11 while i*i <= temp: while temp % i == 0: factors[i] = factors.get(i, 0) + 1 temp //= i i += 2 if temp > 1: factors[temp] = 1 order = phi for prime in factors: test_order = order // prime if pow(g, test_order, p) == 1: order = test_order while order % prime == 0: candidate = order // prime if pow(g, candidate, p) != 1: break order = candidate return order def solve_case(): data = input().split() ptr = 0 p = int(data[ptr]); ptr +=1 A = [[int(data[ptr])%p, int(data[ptr+1])%p], [int(data[ptr+2])%p, int(data[ptr+3])%p]] ptr +=4 B = [[int(data[ptr])%p, int(data[ptr+1])%p], [int(data[ptr+2])%p, int(data[ptr+3])%p]] ptr +=4 det_a = (A[0][0] * A[1][1] - A[0][1] * A[1][0]) % p det_b = (B[0][0] * B[1][1] - B[0][1] * B[1][0]) % p if det_a == 0: if det_b != 0: print(-1) return zero_A = all(A[i][j] == 0 for i in range(2) for j in range(2)) if zero_A: zero_B = all(B[i][j] == 0 for i in range(2) for j in range(2)) if zero_B: print(1) return else: print(-1) return non_zero_rows = [] for i in range(2): for j in range(2): if A[i][j] != 0: non_zero_rows.append((i, j)) if not non_zero_rows: print(-1) return k_values = [] for (i,j) in non_zero_rows: a_ij = A[i][j] b_ij = B[i][j] inv_a = pow(a_ij, p-2, p) k = (b_ij * inv_a) % p k_values.append(k) k = k_values[0] if not all(kv == k for kv in k_values): print(-1) return for i in range(2): for j in range(2): a = A[i][j] b = B[i][j] if (a * k) % p != b % p: print(-1) return t = (A[0][0] + A[1][1]) % p if t == 0: if k == 0: print(2) return elif k == 1: print(1) return else: print(-1) return x = discrete_log(t, k, p) if x == -1: print(-1) return else: n = x + 1 if n <=0: print(-1) return current_k = pow(t, x, p) if current_k != k: print(-1) return print(n) return else: if det_b ==0: print(-1) return log_val = discrete_log(det_a, det_b, p) if log_val == -1: print(-1) return ord_deta = order(det_a, p) if ord_deta == -1: print(-1) return t = (A[0][0] + A[1][1]) % p d = det_a def compute_an_bn(n, p, t, d): if n == 1: a_n = 1 % p b_n = 0 % p return (a_n, b_n) if n == 2: a_n = t % p b_n = (-d) % p return (a_n, b_n) trans_mat = [[t % p, (-d) % p], [1, 0]] power = n - 2 mat_power = mat_pow(trans_mat, power, p) a_2 = t % p a_1 = 1 % p a_n = (mat_power[0][0] * a_2 + mat_power[0][1] * a_1) % p a_n_minus_1 = (mat_power[1][0] * a_2 + mat_power[1][1] * a_1) % p b_n = (-d * a_n_minus_1) % p return (a_n, b_n) candidates = [] if log_val <=0: log_val += ord_deta while log_val <=0: log_val += ord_deta current_n = log_val max_candidates = 100 for i in range(max_candidates): candidates.append(current_n) current_n += ord_deta if current_n > 1e18: break min_n = -1 for candidate in candidates: a_n, b_n = compute_an_bn(candidate, p, t, d) a11 = (a_n * A[0][0] + b_n) % p a12 = (a_n * A[0][1]) % p a21 = (a_n * A[1][0]) % p a22 = (a_n * A[1][1] + b_n) % p if (a11 == B[0][0] % p and a12 == B[0][1] % p and a21 == B[1][0] % p and a22 == B[1][1] % p): if min_n == -1 or candidate < min_n: min_n = candidate if min_n != -1: print(min_n) return else: print(-1) return solve_case()