結果

問題 No.1648 Sum of Powers
ユーザー qwewe
提出日時 2025-05-14 12:51:02
言語 PyPy3
(7.3.15)
結果
WA  
実行時間 -
コード長 3,758 bytes
コンパイル時間 386 ms
コンパイル使用メモリ 81,988 KB
実行使用メモリ 91,592 KB
最終ジャッジ日時 2025-05-14 12:51:30
合計ジャッジ時間 6,639 ms
ジャッジサーバーID
(参考情報)
judge1 / judge4
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 2
other AC * 54 WA * 2
権限があれば一括ダウンロードができます

ソースコード

diff #

import math

MOD = 998244353

def multiply(m1, m2, mod):
    a = (m1[0][0] * m2[0][0] + m1[0][1] * m2[1][0]) % mod
    b = (m1[0][0] * m2[0][1] + m1[0][1] * m2[1][1]) % mod
    c = (m1[1][0] * m2[0][0] + m1[1][1] * m2[1][0]) % mod
    d = (m1[1][0] * m2[0][1] + m1[1][1] * m2[1][1]) % mod
    return [[a, b], [c, d]]

def matrix_pow(mat, power, mod):
    result = [[1, 0], [0, 1]]  # Identity matrix
    while power > 0:
        if power % 2 == 1:
            result = multiply(result, mat, mod)
        mat = multiply(mat, mat, mod)
        power //= 2
    return result

def matrix_inverse(mat, mod):
    a, b = mat[0]
    c, d = mat[1]
    det = (a * d - b * c) % mod
    if det == 0:
        return None
    det_inv = pow(det, -1, mod)
    return [
        [(d * det_inv) % mod, (-b * det_inv) % mod],
        [(-c * det_inv) % mod, (a * det_inv) % mod]
    ]

def multiply_matrix_vector(mat, vec, mod):
    a = (mat[0][0] * vec[0] + mat[0][1] * vec[1]) % mod
    b = (mat[1][0] * vec[0] + mat[1][1] * vec[1]) % mod
    return (a, b)

def solve_matrix_case(A, B, P, Q, mod):
    V0 = (A % mod, 2 % mod)
    V = (P % mod, Q % mod)
    if V == V0:
        return 10**18
    M = [[A % mod, (-B) % mod], [1, 0]]
    Minv = matrix_inverse(M, mod)
    if Minv is None:
        return -1
    m = int(math.isqrt(10**10)) + 1  # m is sqrt(1e10)
    baby_steps = {}
    current = V0
    baby_steps[current] = 0
    for i in range(1, m):
        current = multiply_matrix_vector(M, current, mod)
        if current not in baby_steps:
            baby_steps[current] = i
    Minv_power_m = matrix_pow(Minv, m, mod)
    current_giant = V
    for j in range(m):
        if current_giant in baby_steps:
            k = j * m + baby_steps[current_giant]
            if k >= 0:
                return k + 1
        current_giant = multiply_matrix_vector(Minv_power_m, current_giant, mod)
    return -1

def discrete_log(a, b, p):
    if b == 1:
        return 0
    a = a % p
    b = b % p
    if a == 0:
        if b == 0:
            return 1
        else:
            return -1
    g = math.gcd(a, p)
    if g != 1:
        if b % g != 0:
            return -1
        a_new = a // g
        b_new = b // g
        p_new = p // g
        inv_a_new = pow(a_new, -1, p_new)
        c = (b_new * inv_a_new) % p_new
        y = discrete_log(a_new, c, p_new)
        if y == -1:
            return -1
        return y + 1
    m = int(math.isqrt(p)) + 1
    table = {}
    current = 1
    for j in range(m):
        if current not in table:
            table[current] = j
        current = (current * a) % p
    a_m = pow(a, m, p)
    a_m_inv = pow(a_m, p-2, p)
    gamma = b
    for i in range(m):
        if gamma in table:
            return i * m + table[gamma]
        gamma = (gamma * a_m_inv) % p
    return -1

def solve():
    A, B, P, Q = map(int, input().split())
    A = A % MOD
    B = B % MOD
    P = P % MOD
    Q = Q % MOD
    if B == 0:
        if A == 0:
            if P == 0 and Q == 0:
                print(1000000000000000000)
            else:
                print(-1)
            return
        else:
            if (A * Q) % MOD != P:
                print(-1)
                return
            if Q == 0:
                print(-1)
                return
            x = discrete_log(A, Q, MOD)
            if x == -1:
                print(-1)
            else:
                if x >= 1:
                    print(x + 1)
                else:
                    print(-1)
    else:
        V0 = (A, 2 % MOD)
        V = (P, Q)
        if V == V0:
            print(1000000000000000000)
            return
        res = solve_matrix_case(A, B, P, Q, MOD)
        if res == -1:
            print(-1)
        else:
            print(res)

solve()
0