結果
問題 |
No.1648 Sum of Powers
|
ユーザー |
![]() |
提出日時 | 2025-05-14 12:51:03 |
言語 | PyPy3 (7.3.15) |
結果 |
WA
|
実行時間 | - |
コード長 | 2,686 bytes |
コンパイル時間 | 231 ms |
コンパイル使用メモリ | 82,852 KB |
実行使用メモリ | 95,412 KB |
最終ジャッジ日時 | 2025-05-14 12:51:49 |
合計ジャッジ時間 | 10,030 ms |
ジャッジサーバーID (参考情報) |
judge5 / judge3 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | AC * 2 |
other | AC * 3 WA * 53 |
ソースコード
import math MOD = 998244353 def modinv(a): return pow(a, MOD-2, MOD) def mod_sqrt(a): if a == 0: return 0 if pow(a, (MOD-1)//2, MOD) != 1: return None return pow(a, (MOD + 1) // 4, MOD) def baby_step_giant_step(a, b, mod, max_k): if b == 1: return 0 m = int(math.isqrt(max_k)) + 1 table = {} current = 1 for j in range(m): if current not in table: table[current] = j current = current * a % mod am = pow(a, m, mod) am_inv = pow(am, mod-2, mod) gamma = b for i in range(m): if gamma in table: res = i * m + table[gamma] if res > max_k: return None return res gamma = gamma * am_inv % mod return None def solve(): A, B, P, Q = map(int, input().split()) A %= MOD B %= MOD P %= MOD Q %= MOD D = (A * A - 4 * B) % MOD if D < 0: D += MOD if D == 0: X = (A * modinv(2)) % MOD Y = X inv2 = modinv(2) target_Q = (Q * inv2) % MOD target_P = (P * inv2) % MOD if (X * target_Q) % MOD != target_P: return None if X == 1: return 10**18 else: k = baby_step_giant_step(X, target_Q, MOD, 10**10 - 1) if k is not None: return k + 1 else: return None else: sqrt_D = mod_sqrt(D) if sqrt_D is None: return None sqrt_D = sqrt_D % MOD inv2 = modinv(2) X = ((A + sqrt_D) * inv2) % MOD Y = ((A - sqrt_D) * inv2) % MOD denom = (X - Y) % MOD if denom == 0: return None denom_inv = modinv(denom) U = (P - Y * Q) % MOD U = U * denom_inv % MOD V = (Q - U) % MOD current_P = (X * U + Y * V) % MOD current_Q = (U + V) % MOD if current_P != P or current_Q != Q: return None max_k = 10**10 - 1 k_x = baby_step_giant_step(X, U, MOD, max_k) if k_x is not None: if pow(Y, k_x, MOD) == V: return k_x + 1 k_y = baby_step_giant_step(Y, V, MOD, max_k) if k_y is not None: if pow(X, k_y, MOD) == U: return k_y + 1 if X == 0: if U == 0: if pow(Y, 0, MOD) == V: return 1 else: return None if Y == 0: if V == 0: if pow(X, 0, MOD) == U: return 1 else: return None return None result = solve() print(result)