結果
| 問題 |
No.1648 Sum of Powers
|
| コンテスト | |
| ユーザー |
qwewe
|
| 提出日時 | 2025-05-14 12:51:02 |
| 言語 | PyPy3 (7.3.15) |
| 結果 |
WA
|
| 実行時間 | - |
| コード長 | 3,758 bytes |
| コンパイル時間 | 386 ms |
| コンパイル使用メモリ | 81,988 KB |
| 実行使用メモリ | 91,592 KB |
| 最終ジャッジ日時 | 2025-05-14 12:51:30 |
| 合計ジャッジ時間 | 6,639 ms |
|
ジャッジサーバーID (参考情報) |
judge1 / judge4 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 2 |
| other | AC * 54 WA * 2 |
ソースコード
import math
MOD = 998244353
def multiply(m1, m2, mod):
a = (m1[0][0] * m2[0][0] + m1[0][1] * m2[1][0]) % mod
b = (m1[0][0] * m2[0][1] + m1[0][1] * m2[1][1]) % mod
c = (m1[1][0] * m2[0][0] + m1[1][1] * m2[1][0]) % mod
d = (m1[1][0] * m2[0][1] + m1[1][1] * m2[1][1]) % mod
return [[a, b], [c, d]]
def matrix_pow(mat, power, mod):
result = [[1, 0], [0, 1]] # Identity matrix
while power > 0:
if power % 2 == 1:
result = multiply(result, mat, mod)
mat = multiply(mat, mat, mod)
power //= 2
return result
def matrix_inverse(mat, mod):
a, b = mat[0]
c, d = mat[1]
det = (a * d - b * c) % mod
if det == 0:
return None
det_inv = pow(det, -1, mod)
return [
[(d * det_inv) % mod, (-b * det_inv) % mod],
[(-c * det_inv) % mod, (a * det_inv) % mod]
]
def multiply_matrix_vector(mat, vec, mod):
a = (mat[0][0] * vec[0] + mat[0][1] * vec[1]) % mod
b = (mat[1][0] * vec[0] + mat[1][1] * vec[1]) % mod
return (a, b)
def solve_matrix_case(A, B, P, Q, mod):
V0 = (A % mod, 2 % mod)
V = (P % mod, Q % mod)
if V == V0:
return 10**18
M = [[A % mod, (-B) % mod], [1, 0]]
Minv = matrix_inverse(M, mod)
if Minv is None:
return -1
m = int(math.isqrt(10**10)) + 1 # m is sqrt(1e10)
baby_steps = {}
current = V0
baby_steps[current] = 0
for i in range(1, m):
current = multiply_matrix_vector(M, current, mod)
if current not in baby_steps:
baby_steps[current] = i
Minv_power_m = matrix_pow(Minv, m, mod)
current_giant = V
for j in range(m):
if current_giant in baby_steps:
k = j * m + baby_steps[current_giant]
if k >= 0:
return k + 1
current_giant = multiply_matrix_vector(Minv_power_m, current_giant, mod)
return -1
def discrete_log(a, b, p):
if b == 1:
return 0
a = a % p
b = b % p
if a == 0:
if b == 0:
return 1
else:
return -1
g = math.gcd(a, p)
if g != 1:
if b % g != 0:
return -1
a_new = a // g
b_new = b // g
p_new = p // g
inv_a_new = pow(a_new, -1, p_new)
c = (b_new * inv_a_new) % p_new
y = discrete_log(a_new, c, p_new)
if y == -1:
return -1
return y + 1
m = int(math.isqrt(p)) + 1
table = {}
current = 1
for j in range(m):
if current not in table:
table[current] = j
current = (current * a) % p
a_m = pow(a, m, p)
a_m_inv = pow(a_m, p-2, p)
gamma = b
for i in range(m):
if gamma in table:
return i * m + table[gamma]
gamma = (gamma * a_m_inv) % p
return -1
def solve():
A, B, P, Q = map(int, input().split())
A = A % MOD
B = B % MOD
P = P % MOD
Q = Q % MOD
if B == 0:
if A == 0:
if P == 0 and Q == 0:
print(1000000000000000000)
else:
print(-1)
return
else:
if (A * Q) % MOD != P:
print(-1)
return
if Q == 0:
print(-1)
return
x = discrete_log(A, Q, MOD)
if x == -1:
print(-1)
else:
if x >= 1:
print(x + 1)
else:
print(-1)
else:
V0 = (A, 2 % MOD)
V = (P, Q)
if V == V0:
print(1000000000000000000)
return
res = solve_matrix_case(A, B, P, Q, MOD)
if res == -1:
print(-1)
else:
print(res)
solve()
qwewe