結果

問題 No.1648 Sum of Powers
ユーザー lam6er
提出日時 2025-03-20 19:03:02
言語 PyPy3
(7.3.15)
結果
WA  
実行時間 -
コード長 2,096 bytes
コンパイル時間 183 ms
コンパイル使用メモリ 82,776 KB
実行使用メモリ 75,664 KB
最終ジャッジ日時 2025-03-20 19:03:30
合計ジャッジ時間 14,216 ms
ジャッジサーバーID
(参考情報)
judge4 / judge2
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 2
other AC * 20 WA * 25 TLE * 1 -- * 10
権限があれば一括ダウンロードができます

ソースコード

diff #

MOD = 998244353

def main():
    A, B, P, Q = map(int, input().split())
    A %= MOD
    B %= MOD
    P %= MOD
    Q %= MOD

    if (A * A) % MOD == (4 * B) % MOD:
        print(10**18)
        return

    if B == 0:
        if A == 0:
            if Q == 0 and P == 0:
                print(10**18)
            else:
                pass
            return
        else:
            if (A * Q) % MOD != P % MOD:
                pass
            else:
                target = Q
                a = A
                m = int(MOD**0.5) + 1

                table = {}
                current = 1
                for j in range(m):
                    if current not in table:
                        table[current] = j
                    current = current * a % MOD

                am_inv = pow(a, m * (MOD-2), MOD)
                gamma = target
                answer = -1
                for i in range(m):
                    if gamma in table:
                        answer = i * m + table[gamma]
                        break
                    gamma = gamma * am_inv % MOD
                if answer != -1:
                    N = answer + 1
                    if N < 2:
                        N += (2 - N + m - 1) // m * m
                        while True:
                            if pow(a, N-1, MOD) == Q:
                                break
                            N += m
                    print(N if N <= 10**18 else 10**18)
                else:
                    print(10**18)
                return
    else:
        inv_B = pow(B, MOD-2, MOD)
        target_s1 = A % MOD
        target_s0 = 2 % MOD
        current = P
        prev = Q
        steps = 0
        max_steps = 2 * (10**6)
        while steps <= max_steps:
            if current == target_s1 and prev == target_s0:
                print(steps + 1)
                return
            new_prev = ( (A * prev - current) * inv_B ) % MOD
            new_current = prev
            current, prev = new_current, new_prev
            steps += 1
        print(10**18)

if __name__ == '__main__':
    main()
0