結果
問題 |
No.1648 Sum of Powers
|
ユーザー |
![]() |
提出日時 | 2025-04-24 12:27:11 |
言語 | PyPy3 (7.3.15) |
結果 |
WA
|
実行時間 | - |
コード長 | 2,568 bytes |
コンパイル時間 | 149 ms |
コンパイル使用メモリ | 82,648 KB |
実行使用メモリ | 78,068 KB |
最終ジャッジ日時 | 2025-04-24 12:28:23 |
合計ジャッジ時間 | 8,563 ms |
ジャッジサーバーID (参考情報) |
judge4 / judge3 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | AC * 2 |
other | AC * 21 WA * 35 |
ソースコード
import math mod_val = 998244353 def modinv(a): return pow(a, mod_val - 2, mod_val) def tonelli_shanks(n, p): if n == 0: return 0 ls = pow(n, (p - 1) // 2, p) if ls != 1 and ls != 0: return -1 if ls == 0: return 0 Q = p - 1 S = 0 while Q % 2 == 0: Q //= 2 S += 1 z = 2 while pow(z, (p - 1) // 2, p) != p - 1: z += 1 c = pow(z, Q, p) x = pow(n, (Q + 1) // 2, p) t = pow(n, Q, p) m = S while t != 1: i, temp = 0, t while temp != 1 and i < m: temp = pow(temp, 2, p) i += 1 if i == m: return -1 b = pow(c, 1 << (m - i - 1), p) x = (x * b) % p t = (t * b * b) % p c = (b * b) % p m = i return x def bsgs(base, target, mod, max_k): base %= mod target %= mod if base == 0: if target == 0 and max_k >= 0: return 0 else: return -1 if target == 1: return 0 m = int(math.isqrt(max_k)) + 1 table = {} current = 1 for j in range(m): if current not in table: table[current] = j current = (current * base) % mod inv_base_m = pow(base, m, mod) inv_base_m = pow(inv_base_m, mod - 2, mod) giant = target for i in range(m): if giant in table: k = i * m + table[giant] if k <= max_k: return k giant = (giant * inv_base_m) % mod return -1 A, B, P, Q = map(int, input().split()) A = A % mod_val B = B % mod_val P = P % mod_val Q = Q % mod_val D = (A * A - 4 * B) % mod_val sqrt_D = 0 if D != 0: sqrt_D = tonelli_shanks(D, mod_val) if sqrt_D == -1: sqrt_D = 0 inv_2 = modinv(2) X = (A + sqrt_D) * inv_2 % mod_val Y = (A - sqrt_D) * inv_2 % mod_val if X == Y: X_val = X if X_val == 0: print(2) else: U = (Q * inv_2) % mod_val if X_val == 1: print(10**18) else: max_k = 10**10 - 1 k = bsgs(X_val, U, mod_val, max_k) print(k + 1 if k != -1 else -1) else: denominator = (X - Y) % mod_val inv_denominator = modinv(denominator) U = ((P - Y * Q) % mod_val) * inv_denominator % mod_val V = ((X * Q - P) % mod_val) * inv_denominator % mod_val target_b = (V * U) % mod_val max_k = 10**10 - 1 k = bsgs(X, U, mod_val, max_k) if k != -1: if pow(B, k, mod_val) == target_b: print(k + 1) else: pass else: pass