結果
| 問題 |
No.1648 Sum of Powers
|
| コンテスト | |
| ユーザー |
qwewe
|
| 提出日時 | 2025-05-14 12:51:03 |
| 言語 | PyPy3 (7.3.15) |
| 結果 |
WA
|
| 実行時間 | - |
| コード長 | 2,686 bytes |
| コンパイル時間 | 231 ms |
| コンパイル使用メモリ | 82,852 KB |
| 実行使用メモリ | 95,412 KB |
| 最終ジャッジ日時 | 2025-05-14 12:51:49 |
| 合計ジャッジ時間 | 10,030 ms |
|
ジャッジサーバーID (参考情報) |
judge5 / judge3 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 2 |
| other | AC * 3 WA * 53 |
ソースコード
import math
MOD = 998244353
def modinv(a):
return pow(a, MOD-2, MOD)
def mod_sqrt(a):
if a == 0:
return 0
if pow(a, (MOD-1)//2, MOD) != 1:
return None
return pow(a, (MOD + 1) // 4, MOD)
def baby_step_giant_step(a, b, mod, max_k):
if b == 1:
return 0
m = int(math.isqrt(max_k)) + 1
table = {}
current = 1
for j in range(m):
if current not in table:
table[current] = j
current = current * a % mod
am = pow(a, m, mod)
am_inv = pow(am, mod-2, mod)
gamma = b
for i in range(m):
if gamma in table:
res = i * m + table[gamma]
if res > max_k:
return None
return res
gamma = gamma * am_inv % mod
return None
def solve():
A, B, P, Q = map(int, input().split())
A %= MOD
B %= MOD
P %= MOD
Q %= MOD
D = (A * A - 4 * B) % MOD
if D < 0:
D += MOD
if D == 0:
X = (A * modinv(2)) % MOD
Y = X
inv2 = modinv(2)
target_Q = (Q * inv2) % MOD
target_P = (P * inv2) % MOD
if (X * target_Q) % MOD != target_P:
return None
if X == 1:
return 10**18
else:
k = baby_step_giant_step(X, target_Q, MOD, 10**10 - 1)
if k is not None:
return k + 1
else:
return None
else:
sqrt_D = mod_sqrt(D)
if sqrt_D is None:
return None
sqrt_D = sqrt_D % MOD
inv2 = modinv(2)
X = ((A + sqrt_D) * inv2) % MOD
Y = ((A - sqrt_D) * inv2) % MOD
denom = (X - Y) % MOD
if denom == 0:
return None
denom_inv = modinv(denom)
U = (P - Y * Q) % MOD
U = U * denom_inv % MOD
V = (Q - U) % MOD
current_P = (X * U + Y * V) % MOD
current_Q = (U + V) % MOD
if current_P != P or current_Q != Q:
return None
max_k = 10**10 - 1
k_x = baby_step_giant_step(X, U, MOD, max_k)
if k_x is not None:
if pow(Y, k_x, MOD) == V:
return k_x + 1
k_y = baby_step_giant_step(Y, V, MOD, max_k)
if k_y is not None:
if pow(X, k_y, MOD) == U:
return k_y + 1
if X == 0:
if U == 0:
if pow(Y, 0, MOD) == V:
return 1
else:
return None
if Y == 0:
if V == 0:
if pow(X, 0, MOD) == U:
return 1
else:
return None
return None
result = solve()
print(result)
qwewe