結果

問題 No.1648 Sum of Powers
ユーザー qwewe
提出日時 2025-05-14 12:51:03
言語 PyPy3
(7.3.15)
結果
WA  
実行時間 -
コード長 2,686 bytes
コンパイル時間 231 ms
コンパイル使用メモリ 82,852 KB
実行使用メモリ 95,412 KB
最終ジャッジ日時 2025-05-14 12:51:49
合計ジャッジ時間 10,030 ms
ジャッジサーバーID
(参考情報)
judge5 / judge3
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 2
other AC * 3 WA * 53
権限があれば一括ダウンロードができます

ソースコード

diff #

import math

MOD = 998244353

def modinv(a):
    return pow(a, MOD-2, MOD)

def mod_sqrt(a):
    if a == 0:
        return 0
    if pow(a, (MOD-1)//2, MOD) != 1:
        return None
    return pow(a, (MOD + 1) // 4, MOD)

def baby_step_giant_step(a, b, mod, max_k):
    if b == 1:
        return 0
    m = int(math.isqrt(max_k)) + 1
    table = {}
    current = 1
    for j in range(m):
        if current not in table:
            table[current] = j
        current = current * a % mod
    am = pow(a, m, mod)
    am_inv = pow(am, mod-2, mod)
    gamma = b
    for i in range(m):
        if gamma in table:
            res = i * m + table[gamma]
            if res > max_k:
                return None
            return res
        gamma = gamma * am_inv % mod
    return None

def solve():
    A, B, P, Q = map(int, input().split())
    A %= MOD
    B %= MOD
    P %= MOD
    Q %= MOD
    D = (A * A - 4 * B) % MOD
    if D < 0:
        D += MOD
    if D == 0:
        X = (A * modinv(2)) % MOD
        Y = X
        inv2 = modinv(2)
        target_Q = (Q * inv2) % MOD
        target_P = (P * inv2) % MOD
        if (X * target_Q) % MOD != target_P:
            return None
        if X == 1:
            return 10**18
        else:
            k = baby_step_giant_step(X, target_Q, MOD, 10**10 - 1)
            if k is not None:
                return k + 1
            else:
                return None
    else:
        sqrt_D = mod_sqrt(D)
        if sqrt_D is None:
            return None
        sqrt_D = sqrt_D % MOD
        inv2 = modinv(2)
        X = ((A + sqrt_D) * inv2) % MOD
        Y = ((A - sqrt_D) * inv2) % MOD
        denom = (X - Y) % MOD
        if denom == 0:
            return None
        denom_inv = modinv(denom)
        U = (P - Y * Q) % MOD
        U = U * denom_inv % MOD
        V = (Q - U) % MOD
        current_P = (X * U + Y * V) % MOD
        current_Q = (U + V) % MOD
        if current_P != P or current_Q != Q:
            return None
        max_k = 10**10 - 1
        k_x = baby_step_giant_step(X, U, MOD, max_k)
        if k_x is not None:
            if pow(Y, k_x, MOD) == V:
                return k_x + 1
        k_y = baby_step_giant_step(Y, V, MOD, max_k)
        if k_y is not None:
            if pow(X, k_y, MOD) == U:
                return k_y + 1
        if X == 0:
            if U == 0:
                if pow(Y, 0, MOD) == V:
                    return 1
                else:
                    return None
        if Y == 0:
            if V == 0:
                if pow(X, 0, MOD) == U:
                    return 1
                else:
                    return None
        return None

result = solve()
print(result)
0