結果

問題 No.1648 Sum of Powers
ユーザー gew1fw
提出日時 2025-06-12 19:15:13
言語 PyPy3
(7.3.15)
結果
WA  
実行時間 -
コード長 2,838 bytes
コンパイル時間 174 ms
コンパイル使用メモリ 82,100 KB
実行使用メモリ 101,408 KB
最終ジャッジ日時 2025-06-12 19:15:33
合計ジャッジ時間 9,309 ms
ジャッジサーバーID
(参考情報)
judge4 / judge3
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 2
other AC * 29 WA * 27
権限があれば一括ダウンロードができます

ソースコード

diff #

MOD = 998244353

def main():
    import sys
    A, B, P, Q = map(int, sys.stdin.readline().split())
    
    if B == 0:
        if A == 0:
            if P == 0 and Q == 0:
                print(2)
            else:
                pass
            return
        else:
            if Q == 0:
                if P == 0:
                    print(2)
                else:
                    pass
                return
            inv_Q = pow(Q, MOD-2, MOD)
            expected_A = (P * inv_Q) % MOD
            if expected_A != A:
                return
            X = A
            C = Q
            if X == 0:
                if C == 0:
                    print(1000000000000000000)
                else:
                    pass
                return
            m = int(1e5)
            table = {}
            current = 1
            for j in range(m):
                if current not in table:
                    table[current] = j
                current = (current * X) % MOD
            X_inv_m = pow(X, m * (MOD-2), MOD)
            giant = C
            found = False
            for i in range(m+1):
                if giant in table:
                    k = i * m + table[giant]
                    N = k + 1
                    if N >= 2 and N <= 1e18:
                        print(N)
                        found = True
                        break
                giant = (giant * X_inv_m) % MOD
            if not found:
                print(1000000000000000000)
            return
    else:
        inv_B = pow(B, MOD-2, MOD)
        m = 10**5
        forward = {}
        s_prev, s_curr = 2, A
        forward[(s_prev, s_curr)] = 0
        prev_pair = None
        for n in range(1, m+1):
            s_next = (A * s_curr - B * s_prev) % MOD
            new_s_prev, new_s_curr = s_curr, s_next
            current_pair = (new_s_prev, new_s_curr)
            if current_pair == (s_prev, s_curr):
                if new_s_prev == new_s_curr and new_s_prev == s_prev:
                    print(1000000000000000000)
                    return
            if current_pair not in forward:
                forward[current_pair] = n
            s_prev, s_curr = new_s_prev, new_s_curr
            prev_pair = current_pair
        
        current_s_prev, current_s_curr = Q, P
        for t in range(m+1):
            current_pair = (current_s_prev, current_s_curr)
            if current_pair in forward:
                k = forward[current_pair]
                N = k + t + 1
                if 2 <= N <= 10**18:
                    print(N)
                    return
            s_prev_prev = (A * current_s_prev - current_s_curr) * inv_B % MOD
            current_s_prev, current_s_curr = s_prev_prev, current_s_prev
        
        print(1000000000000000000)
        return

if __name__ == "__main__":
    main()
0