結果
問題 |
No.950 行列累乗
|
ユーザー |
![]() |
提出日時 | 2025-06-12 21:37:38 |
言語 | PyPy3 (7.3.15) |
結果 |
WA
|
実行時間 | - |
コード長 | 4,241 bytes |
コンパイル時間 | 197 ms |
コンパイル使用メモリ | 82,200 KB |
実行使用メモリ | 141,060 KB |
最終ジャッジ日時 | 2025-06-12 21:40:16 |
合計ジャッジ時間 | 7,333 ms |
ジャッジサーバーID (参考情報) |
judge2 / judge1 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | AC * 4 |
other | AC * 30 WA * 27 |
ソースコード
import sys from math import gcd from random import randint MOD = 10**18 + 3 def modinv(a, mod): g, x, y = extended_gcd(a, mod) if g != 1: return None else: return x % mod def extended_gcd(a, b): if a == 0: return (b, 0, 1) else: g, y, x = extended_gcd(b % a, a) return (g, x - (b // a) * y, y) def matrix_mult(m1, m2, mod): res = [[0]*len(m2[0]) for _ in range(len(m1))] for i in range(len(m1)): for k in range(len(m2)): if m1[i][k] == 0: continue for j in range(len(m2[0])): res[i][j] = (res[i][j] + m1[i][k] * m2[k][j]) % mod return res def matrix_pow(mat, power, mod): result = [[1 if i == j else 0 for j in range(len(mat))] for i in range(len(mat))] while power > 0: if power % 2 == 1: result = matrix_mult(result, mat, mod) mat = matrix_mult(mat, mat, mod) power //= 2 return result def solve_discrete_log(a, b, mod): if b == 1: return 0 if a == 0: if b == 0: return 1 else: return -1 if mod == 1: return 0 if gcd(a, mod) != 1: if b % gcd(a, mod) != 0: return -1 a //= gcd(a, mod) b //= gcd(a, mod) mod //= gcd(a, mod) max_d = int(mod ** 0.5) + 1 value = {} current = 1 for i in range(max_d): if current not in value: value[current] = i current = (current * a) % mod a_inv = modinv(a, mod) if a_inv is None: return -1 gamma = pow(a_inv, max_d, mod) current = b % mod for i in range(max_d): if current in value: return i * max_d + value[current] current = (current * gamma) % mod return -1 def main(): p = int(sys.stdin.readline()) A = [] for _ in range(2): A.append(list(map(int, sys.stdin.readline().split()))) B = [] for _ in range(2): B.append(list(map(int, sys.stdin.readline().split()))) a = None valid = True A12 = A[0][1] % p B12 = B[0][1] % p A21 = A[1][0] % p B21 = B[1][0] % p if A12 != 0: a_A12 = B12 * modinv(A12, p) % p a = a_A12 if A21 != 0: a_A21 = B21 * modinv(A21, p) % p if a is None: a = a_A21 else: if a != a_A21: print(-1) return if a is None: a = 0 else: if A12 != 0: if (a * A12) % p != B12: print(-1) return if A21 != 0: if (a * A21) % p != B21: print(-1) return b = (B[0][0] - a * A[0][0]) % p if (a * A[1][1] + b) % p != B[1][1]: print(-1) return t = (A[0][0] + A[1][1]) % p d = (A[0][0] * A[1][1] - A[0][1] * A[1][0]) % p if d == 0: if b != 0: print(-1) return current_a = 1 current_b = 0 if t == 0: if a != 0: print(-1) return else: print(1) return if t == 1: if a == current_a: print(1) return else: print(-1) return n_minus_1 = solve_discrete_log(t, a, p) if n_minus_1 == -1: print(-1) return else: print(n_minus_1 + 1) return else: M = [[t, -d], [1, 0]] a1 = 1 a2 = t * a1 + 0 a2_mod = a2 % p if a2 == a and 0 == b: print(2) return a_list = [a1, a2_mod] n = 2 max_tries = 1000000 found = False while max_tries > 0: n += 1 a_next = (t * a_list[-1] - d * a_list[-2]) % p a_list.append(a_next) b_next = (-d * a_list[-2]) % p if a_next == a and b_next == b: found = True break max_tries -= 1 if found: print(n) return else: print(-1) return if __name__ == '__main__': main()