結果

問題 No.1648 Sum of Powers
ユーザー qwewe
提出日時 2025-04-24 12:27:11
言語 PyPy3
(7.3.15)
結果
WA  
実行時間 -
コード長 2,568 bytes
コンパイル時間 149 ms
コンパイル使用メモリ 82,648 KB
実行使用メモリ 78,068 KB
最終ジャッジ日時 2025-04-24 12:28:23
合計ジャッジ時間 8,563 ms
ジャッジサーバーID
(参考情報)
judge4 / judge3
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 2
other AC * 21 WA * 35
権限があれば一括ダウンロードができます

ソースコード

diff #

import math

mod_val = 998244353

def modinv(a):
    return pow(a, mod_val - 2, mod_val)

def tonelli_shanks(n, p):
    if n == 0:
        return 0
    ls = pow(n, (p - 1) // 2, p)
    if ls != 1 and ls != 0:
        return -1
    if ls == 0:
        return 0
    Q = p - 1
    S = 0
    while Q % 2 == 0:
        Q //= 2
        S += 1
    z = 2
    while pow(z, (p - 1) // 2, p) != p - 1:
        z += 1
    c = pow(z, Q, p)
    x = pow(n, (Q + 1) // 2, p)
    t = pow(n, Q, p)
    m = S
    while t != 1:
        i, temp = 0, t
        while temp != 1 and i < m:
            temp = pow(temp, 2, p)
            i += 1
        if i == m:
            return -1
        b = pow(c, 1 << (m - i - 1), p)
        x = (x * b) % p
        t = (t * b * b) % p
        c = (b * b) % p
        m = i
    return x

def bsgs(base, target, mod, max_k):
    base %= mod
    target %= mod
    if base == 0:
        if target == 0 and max_k >= 0:
            return 0
        else:
            return -1
    if target == 1:
        return 0
    m = int(math.isqrt(max_k)) + 1
    table = {}
    current = 1
    for j in range(m):
        if current not in table:
            table[current] = j
        current = (current * base) % mod
    inv_base_m = pow(base, m, mod)
    inv_base_m = pow(inv_base_m, mod - 2, mod)
    giant = target
    for i in range(m):
        if giant in table:
            k = i * m + table[giant]
            if k <= max_k:
                return k
        giant = (giant * inv_base_m) % mod
    return -1

A, B, P, Q = map(int, input().split())

A = A % mod_val
B = B % mod_val
P = P % mod_val
Q = Q % mod_val

D = (A * A - 4 * B) % mod_val

sqrt_D = 0
if D != 0:
    sqrt_D = tonelli_shanks(D, mod_val)
    if sqrt_D == -1:
        sqrt_D = 0

inv_2 = modinv(2)
X = (A + sqrt_D) * inv_2 % mod_val
Y = (A - sqrt_D) * inv_2 % mod_val

if X == Y:
    X_val = X
    if X_val == 0:
        print(2)
    else:
        U = (Q * inv_2) % mod_val
        if X_val == 1:
            print(10**18)
        else:
            max_k = 10**10 - 1
            k = bsgs(X_val, U, mod_val, max_k)
            print(k + 1 if k != -1 else -1)
else:
    denominator = (X - Y) % mod_val
    inv_denominator = modinv(denominator)
    U = ((P - Y * Q) % mod_val) * inv_denominator % mod_val
    V = ((X * Q - P) % mod_val) * inv_denominator % mod_val
    target_b = (V * U) % mod_val
    max_k = 10**10 - 1
    k = bsgs(X, U, mod_val, max_k)
    if k != -1:
        if pow(B, k, mod_val) == target_b:
            print(k + 1)
        else:
            pass
    else:
        pass
0