結果

問題 No.1648 Sum of Powers
ユーザー lam6er
提出日時 2025-04-09 21:05:37
言語 PyPy3
(7.3.15)
結果
WA  
実行時間 -
コード長 3,897 bytes
コンパイル時間 154 ms
コンパイル使用メモリ 82,168 KB
実行使用メモリ 79,532 KB
最終ジャッジ日時 2025-04-09 21:07:42
合計ジャッジ時間 6,474 ms
ジャッジサーバーID
(参考情報)
judge4 / judge1
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 2
other AC * 48 WA * 8
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
MOD = 998244353

def discrete_log(a, b, mod):
    if b == 1:
        return 0
    a %= mod
    b %= mod
    if a == 0:
        if b == 0:
            return 1  # Only valid when b ==0 and a ==0, which is handled earlier
        else:
            return -1
    m = int(mod**0.5) + 1
    table = {}
    current = 1
    for j in range(m):
        if current not in table:
            table[current] = j
        current = current * a % mod
    am = pow(a, m, mod)
    am_inv = pow(am, mod-2, mod)
    gamma = b
    for i in range(m):
        if gamma in table:
            return i * m + table[gamma]
        gamma = gamma * am_inv % mod
    return -1

def solve_matrix_case(A, B, P, Q, MOD):
    inv_B = pow(B, MOD-2, MOD)
    R = [
        [0, 1],
        [(MOD - inv_B) % MOD, (A * inv_B) % MOD]
    ]
    
    def multiply_vec(mat, vec):
        x = (mat[0][0] * vec[0] + mat[0][1] * vec[1]) % MOD
        y = (mat[1][0] * vec[0] + mat[1][1] * vec[1]) % MOD
        return (x, y)
    
    m = int(MOD ** 0.5) + 1
    baby_map = {}
    current_vec = (P % MOD, Q % MOD)
    for r in range(m):
        if current_vec not in baby_map:
            baby_map[current_vec] = r
        x, y = current_vec
        new_x = y
        new_y = (-inv_B * x + A * inv_B * y) % MOD
        current_vec = (new_x, new_y)
    
    # Prepare R^m
    mat = [[1 if i == j else 0 for j in range(2)] for i in range(2)]
    current = R
    power = m
    while power > 0:
        if power % 2 == 1:
            mat = [
                [
                    (mat[0][0] * current[0][0] + mat[0][1] * current[1][0]) % MOD,
                    (mat[0][0] * current[0][1] + mat[0][1] * current[1][1]) % MOD
                ],
                [
                    (mat[1][0] * current[0][0] + mat[1][1] * current[1][0]) % MOD,
                    (mat[1][0] * current[0][1] + mat[1][1] * current[1][1]) % MOD
                ]
            ]
        current = [
            [
                (current[0][0] * current[0][0] + current[0][1] * current[1][0]) % MOD,
                (current[0][0] * current[0][1] + current[0][1] * current[1][1]) % MOD
            ],
            [
                (current[1][0] * current[0][0] + current[1][1] * current[1][0]) % MOD,
                (current[1][0] * current[0][1] + current[1][1] * current[1][1]) % MOD
            ]
        ]
        power //= 2
    
    a, b = mat[0][0], mat[0][1]
    c, d = mat[1][0], mat[1][1]
    det = (a * d - b * c) % MOD
    if det == 0:
        return -1
    inv_det = pow(det, MOD-2, MOD)
    inv_mat = [
        [(d * inv_det) % MOD, (-b * inv_det) % MOD],
        [(-c * inv_det) % MOD, (a * inv_det) % MOD]
    ]
    
    current_giant = (A % MOD, 2 % MOD)
    for q in range(m):
        if current_giant in baby_map:
            r = baby_map[current_giant]
            return r + q * m
        current_giant = multiply_vec(inv_mat, current_giant)
    return -1

A, B, P, Q = map(int, sys.stdin.readline().split())

A_mod = A % MOD
B_mod = B % MOD
P_mod = P % MOD
Q_mod = Q % MOD

if B_mod == 0:
    A_mod = A % MOD
    if A_mod == 0:
        if P_mod == 0 and Q_mod == 0:
            print(1000000000000000000)
        else:
            pass
    else:
        if (A_mod * Q_mod) % MOD != P_mod:
            pass
        else:
            if Q_mod == 0:
                print(1000000000000000000)
            else:
                x = discrete_log(A_mod, Q_mod, MOD)
                if x == -1:
                    pass
                else:
                    N = x + 1
                    print(N)
else:
    K = solve_matrix_case(A_mod, B_mod, P_mod, Q_mod, MOD)
    if K == -1:
        pass
    else:
        N_candidate = K + 1
        if N_candidate >= 2:
            print(N_candidate)
        else:
            if (P_mod, Q_mod) == (A_mod, 2 % MOD):
                print(1000000000000000000)
            else:
                pass
0