結果
| 問題 | No.1648 Sum of Powers |
| コンテスト | |
| ユーザー | qwewe |
| 提出日時 | 2025-05-14 12:51:02 |
| 言語 | PyPy3 (7.3.15) |
| 結果 | WA |
| 実行時間 | - |
| コード長 | 2,568 bytes |
| コンパイル時間 | 190 ms |
| コンパイル使用メモリ | 82,172 KB |
| 実行使用メモリ | 78,236 KB |
| 最終ジャッジ日時 | 2025-05-14 12:51:42 |
| 合計ジャッジ時間 | 8,567 ms |
| ジャッジサーバーID (参考情報) | judge4 / judge2 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 2 |
| other | AC * 21 WA * 35 |
ソースコード
import math
# Modulus shared by every computation in this script (the prime 998244353).
mod_val = 998244353


def modinv(a):
    """Return the multiplicative inverse of ``a`` modulo ``mod_val``.

    The inverse is computed as ``a ** (mod_val - 2)`` (Fermat's little
    theorem), which is valid because ``mod_val`` is prime. When ``a`` is a
    multiple of ``mod_val`` no inverse exists and the result is 0.
    """
    fermat_exponent = mod_val - 2
    return pow(a, fermat_exponent, mod_val)
def tonelli_shanks(n, p):
    """Return a square root of ``n`` modulo ``p`` (Tonelli-Shanks).

    ``p`` is expected to be an odd prime.

    Returns:
        0 when ``n`` is congruent to 0 modulo ``p``;
        -1 when Euler's criterion shows that ``n`` has no square root;
        otherwise a value ``x`` with ``x * x % p == n % p``.
    """
    if n == 0:
        return 0

    # Euler's criterion: n ** ((p - 1) / 2) is 1 for residues, 0 for
    # multiples of p, and anything else means "no square root".
    half_order = (p - 1) // 2
    euler = pow(n, half_order, p)
    if euler == 0:
        return 0
    if euler != 1:
        return -1

    # Split p - 1 into odd_part * 2 ** two_power.
    odd_part, two_power = p - 1, 0
    while odd_part % 2 == 0:
        odd_part //= 2
        two_power += 1

    # Smallest z >= 2 that is a quadratic non-residue modulo p.
    non_residue = 2
    while pow(non_residue, half_order, p) != p - 1:
        non_residue += 1

    c = pow(non_residue, odd_part, p)
    root = pow(n, (odd_part + 1) // 2, p)
    t = pow(n, odd_part, p)
    m = two_power
    while t != 1:
        # Find the least i with t ** (2 ** i) == 1 (searching below m).
        i = 0
        probe = t
        while probe != 1 and i < m:
            probe = pow(probe, 2, p)
            i += 1
        if i == m:
            return -1
        b = pow(c, 1 << (m - i - 1), p)
        root = root * b % p
        c = b * b % p
        t = t * c % p
        m = i
    return root
def bsgs(base, target, mod, max_k):
    """Baby-step giant-step discrete logarithm.

    Finds the smallest ``k`` with ``0 <= k <= max_k`` such that
    ``pow(base, k, mod) == target % mod``.

    Args:
        base: Base of the power; reduced modulo ``mod`` first.
        target: Value to reach; reduced modulo ``mod`` first.
        mod: Modulus. The giant-step factor is inverted with Fermat's little
            theorem, so ``mod`` is expected to be prime.
        max_k: Largest exponent that may be returned.

    Returns:
        The smallest matching exponent, or -1 when none exists in range.
    """
    if max_k < 0:
        # No exponent is allowed at all (and math.isqrt would reject it).
        return -1

    base %= mod
    target %= mod

    # base ** 0 is 1 for every base, including 0, so check this before the
    # zero-base case. ``1 % mod`` keeps the degenerate mod == 1 consistent.
    if target == 1 % mod:
        return 0
    if base == 0:
        # 0 ** k is 0 for every k >= 1, so 1 is the smallest exponent that
        # reaches target 0; no other target is reachable.
        return 1 if target == 0 and max_k >= 1 else -1

    # step * step > max_k, so k = i * step + j with 0 <= i, j < step covers
    # every exponent in [0, max_k].
    step = math.isqrt(max_k) + 1

    # Baby steps: remember the smallest j for each value of base ** j.
    baby = {}
    power = 1
    for j in range(step):
        if power not in baby:
            baby[power] = j
        power = power * base % mod

    # ``power`` is now base ** step; invert it for the giant steps.
    giant_factor = pow(power, mod - 2, mod)

    gamma = target
    for i in range(step):
        j = baby.get(gamma)
        if j is not None:
            k = i * step + j
            if k <= max_k:
                return k
        gamma = gamma * giant_factor % mod
    return -1
# Read A, B, P, Q and reduce each of them into [0, mod_val).
A, B, P, Q = map(int, input().split())
A %= mod_val
B %= mod_val
P %= mod_val
Q %= mod_val

# X and Y are the roots of t^2 - A*t + B (so X + Y = A and X * Y = B);
# D is the discriminant of that polynomial.
D = (A * A - 4 * B) % mod_val
sqrt_D = 0
if D != 0:
    sqrt_D = tonelli_shanks(D, mod_val)
    if sqrt_D == -1:
        # NOTE(review): -1 means D has no square root modulo mod_val, so the
        # roots cannot be represented as residues. Falling back to 0 handles
        # this like a repeated root, which does not look justified -- this
        # case presumably needs arithmetic in the quadratic extension field.
        sqrt_D = 0

inv_2 = modinv(2)
X = (A + sqrt_D) * inv_2 % mod_val
Y = (A - sqrt_D) * inv_2 % mod_val

if X == Y:
    # Repeated root: Q is treated as 2 * X ** k, and k + 1 is printed.
    # NOTE(review): P is never checked in this branch -- confirm.
    if X == 0:
        print(2)
    else:
        U = (Q * inv_2) % mod_val
        if X == 1:
            # Every power of 1 is 1, so the exponent is not constrained here.
            print(10**18)
        else:
            max_k = 10**10 - 1
            k = bsgs(X, U, mod_val, max_k)
            print(k + 1 if k != -1 else -1)
else:
    # Distinct roots: solve  U + V = Q,  X*U + Y*V = P  for U and V, which
    # stand for X ** k and Y ** k with the same unknown exponent k.
    denominator = (X - Y) % mod_val
    inv_denominator = modinv(denominator)
    U = ((P - Y * Q) % mod_val) * inv_denominator % mod_val
    V = ((X * Q - P) % mod_val) * inv_denominator % mod_val
    # B ** k == (X * Y) ** k must equal U * V for a consistent exponent.
    target_b = (V * U) % mod_val
    max_k = 10**10 - 1
    k = bsgs(X, U, mod_val, max_k)
    if k != -1 and pow(B, k, mod_val) == target_b:
        print(k + 1)
    else:
        # Report failure explicitly, like the repeated-root branch does,
        # instead of finishing without printing anything.
        # NOTE(review): only the smallest k with X ** k == U is tried; a
        # larger exponent congruent to it might satisfy the B ** k check.
        print(-1)
qwewe