結果

問題 No.981 一般冪乗根
ユーザー lam6er
提出日時 2025-04-15 22:31:48
言語 PyPy3 (7.3.15)
結果
TLE  
実行時間 -
コード長 4,853 bytes
コンパイル時間 372 ms
コンパイル使用メモリ 81,776 KB
実行使用メモリ 77,256 KB
最終ジャッジ日時 2025-04-15 22:34:20
合計ジャッジ時間 19,444 ms
ジャッジサーバーID (参考情報) judge1 / judge2
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
other TLE * 1 -- * 43
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
import math
import random

def input():
    """Return the whole of standard input as a single string.

    Shadows the built-in input() on purpose: solve() tokenizes everything
    from one call instead of reading line by line.
    """
    stdin_text = sys.stdin.read()
    return stdin_text

def is_prime(n):
    """Return True if n is prime, using a deterministic Miller-Rabin test.

    Trial division by the primes up to 37 settles small n and cheap factors.
    The remaining candidates are checked against the witness set
    {2, 325, 9375, 28178, 450775, 9780504, 1795265022}, which is published
    as deterministic for n < 2**64.  For larger n the answer is only
    probabilistic.
    """
    if n < 2:
        return False
    for small in (2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37):
        if n % small == 0:
            return n == small
    # Write n - 1 = d * 2**s with d odd.
    d = n - 1
    s = 0
    while d % 2 == 0:
        d //= 2
        s += 1
    for base in (2, 325, 9375, 28178, 450775, 9780504, 1795265022):
        # The deterministic guarantee assumes every base is applied as
        # base mod n; skipping bases >= n would drop witnesses for small n.
        a = base % n
        if a == 0:
            continue
        x = pow(a, d, n)
        if x == 1 or x == n - 1:
            continue
        for _ in range(s - 1):
            x = x * x % n
            if x == n - 1:
                break
        else:
            return False  # a is a witness: n is composite
    return True

def pow_mod(a, b, p):
    """Return pow(a, b, p): a raised to b, reduced modulo p."""
    result = pow(a, b, p)
    return result

def gcd(a, b):
    """Return the greatest common divisor of the integers a and b.

    Delegates to math.gcd, which runs in C and always returns a
    non-negative value.  A plain Euclid loop built on Python's floor-mod can
    return a negative result when b is negative, e.g. -2 for (4, -6).
    """
    return math.gcd(a, b)

def exgcd(a, b):
    """Extended Euclid: return (g, x, y) with a*x + b*y == g.

    g is gcd(a, b) for non-negative inputs (it may carry a sign otherwise).
    This iterative form produces exactly the same (x, y) pair as the
    textbook recursive formulation: both evaluate the same product of
    symmetric quotient matrices.
    """
    r_prev, r_cur = a, b
    x_prev, x_cur = 1, 0
    y_prev, y_cur = 0, 1
    while r_cur != 0:
        quot = r_prev // r_cur
        r_prev, r_cur = r_cur, r_prev - quot * r_cur
        x_prev, x_cur = x_cur, x_prev - quot * x_cur
        y_prev, y_cur = y_cur, y_prev - quot * y_cur
    return r_prev, x_prev, y_prev

def modinv(a, m):
    g, x, y = exgcd(a, m)
    if g != 1:
        return None
    return x % m

def tonelli_shanks(n, p):
    """Return a square root of n modulo the prime p, or None if none exists.

    n is reduced modulo p first; 0 has the square root 0.  p is assumed
    prime (not checked).
    """
    n %= p
    # These checks must precede Euler's criterion, which reports 0 as a
    # non-residue for odd p.
    if n == 0:
        return 0
    if p == 2:
        return n
    # Euler's criterion: n is a quadratic residue iff n**((p-1)/2) == 1.
    if pow(n, (p - 1) // 2, p) != 1:
        return None
    if p % 4 == 3:
        return pow(n, (p + 1) // 4, p)
    # p - 1 = Q * 2**S with Q odd.
    Q, S = p - 1, 0
    while Q % 2 == 0:
        Q //= 2
        S += 1
    # Any quadratic non-residue works; scan upward from 2.
    z = 2
    while pow(z, (p - 1) // 2, p) != p - 1:
        z += 1
    c = pow(z, Q, p)
    x = pow(n, (Q + 1) // 2, p)
    t = pow(n, Q, p)
    m = S
    while t != 1:
        # Least i with t**(2**i) == 1.
        i, t_sq = 0, t
        while t_sq != 1 and i < m:
            t_sq = t_sq * t_sq % p
            i += 1
        if i == m:
            return None  # unreachable once Euler's criterion has passed
        b = pow(c, 1 << (m - i - 1), p)
        x = x * b % p
        c = b * b % p
        t = t * c % p
        m = i
    return x

def _amm_inverse(u, m):
    """Return the inverse of u modulo m, reduced into [0, m).

    Requires gcd(u, m) == 1 (not checked).  For m == 1 the result is 0.
    """
    r_prev, r_cur = u % m, m
    c_prev, c_cur = 1, 0
    while r_cur:
        q = r_prev // r_cur
        r_prev, r_cur = r_cur, r_prev - q * r_cur
        c_prev, c_cur = c_cur, c_prev - q * c_cur
    return c_prev % m


def _amm_dlog(b, z, r, e, p):
    """Return D in [0, r**e) with z**D == b (mod p).

    z must have multiplicative order exactly r**e modulo p, with r prime.
    Pohlig-Hellman recovers D one base-r digit at a time, solving each
    digit by baby-step giant-step in the order-r subgroup.  That costs
    O(e * sqrt(r)) multiplications.  Raises ValueError if b is not in <z>.
    """
    order = r ** e
    gamma = pow(z, order // r, p)  # generates the subgroup of order r
    m = int(r ** 0.5) + 1
    while m * m <= r:  # guard against float rounding in the square root
        m += 1
    baby = {}
    cur = 1
    for j in range(m):
        baby.setdefault(cur, j)
        cur = cur * gamma % p
    giant = pow(gamma, r - m % r, p)  # gamma**(-m)
    log = 0
    place = 1  # r**i for the digit currently being solved
    for _ in range(e):
        # Strip the digits found so far, then project onto the order-r
        # subgroup; the result equals gamma**digit.
        h = pow(b * pow(z, order - log, p) % p, order // (place * r), p)
        digit = None
        cur = h
        for step in range(m):
            j = baby.get(cur)
            if j is not None:
                digit = step * m + j
                break
            cur = cur * giant % p
        if digit is None:
            raise ValueError("b is not a power of z")
        log += digit * place
        place *= r
    return log


def _amm_prime_root(a, r, p):
    """Return x with x**r == a (mod p).

    Preconditions (not checked): p is prime, r is a prime dividing p - 1,
    and a is a nonzero r-th power residue, i.e. a**((p-1)/r) == 1 (mod p).
    """
    n = p - 1
    s, e = n, 0
    while s % r == 0:
        s //= r
        e += 1
    # n = r**e * s with r not dividing s.  Choosing r*t == 1 (mod s) gives
    # x0**r == a * err with err = a**(r*t - 1).  err is an s-th power, so it
    # lies in the Sylow r-subgroup, and in fact in its index-r subgroup
    # because a is an r-th power residue.
    x = pow(a, _amm_inverse(r, s), p)
    if e == 1:
        return x  # the index-r subgroup is trivial, so err == 1 already
    err = pow(x, r, p) * pow(a, p - 2, p) % p  # x**r / a
    # z = rho**s generates the Sylow r-subgroup (order r**e) whenever rho is
    # not an r-th power residue; scan small candidates for one.
    rho = 2
    while pow(rho, n // r, p) == 1:
        rho += 1
    z = pow(rho, s, p)
    order = n // s
    log = _amm_dlog(err, z, r, e, p)  # divisible by r, see above
    # c = z**(-log/r) satisfies c**r == err**-1, cancelling the error term.
    return x * pow(z, (order - log // r) % order, p) % p


def amm(a, k, p):
    """Return some x with x**k == a (mod p), or None if no such x exists.

    p must be prime (not checked).  With d = gcd(k, p - 1), a is a k-th
    power residue iff a**((p-1)/d) == 1.  A d-th root y is taken one prime
    factor of d at a time (Adleman-Manders-Miller).  The answer is then
    x = y**t with t = (k/d)**-1 mod (p-1)/d.
    """
    a %= p
    if k == 0:
        # x**0 == 1 for every x, so only a == 1 is reachable.
        return 1 if a == 1 else None
    if a == 0:
        return 0
    n = p - 1
    d = math.gcd(k, n)
    if pow(a, n // d, p) != 1:
        return None
    # When d | p - 1, any r-th root (r | d) of a d-th power residue is itself
    # a (d/r)-th power residue, so prime roots can be chained in any order.
    y = a
    rest = d
    r = 2
    while r * r <= rest:
        while rest % r == 0:
            y = _amm_prime_root(y, r, p)
            rest //= r
        r += 1 if r == 2 else 2
    if rest > 1:
        y = _amm_prime_root(y, rest, p)
    # y**d == a and gcd(k/d, n/d) == 1, so with k/d * t == 1 + j*(n/d):
    # (y**t)**k == y**(d + j*n) == a.
    return pow(y, _amm_inverse(k // d, n // d), p)

def factor(n):
    """Return the factorization of n as a list of (prime, exponent) pairs.

    Trial division by every integer up to sqrt of the unfactored part.  For
    n >= 2 the primes come out in increasing order, and any cofactor left
    above 1 is prime.  factor(1) is [].
    """
    remaining = n
    pairs = []
    divisor = 2
    while divisor * divisor <= remaining:
        exponent = 0
        while remaining % divisor == 0:
            remaining //= divisor
            exponent += 1
        if exponent:
            pairs.append((divisor, exponent))
        divisor += 1
    if remaining != 1:
        pairs.append((remaining, 1))
    return pairs

def solve():
    """Answer every query on stdin, one line per query.

    Input: T, followed by T whitespace-separated triples (p, k, a).  Each
    answer is some x with x**k == a (mod p), or -1 if none exists.  p is
    assumed prime.  Answers are buffered and written in a single call,
    which is far cheaper than a print() per query.
    """
    data = input().split()
    T = int(data[0])
    answers = []
    idx = 1
    for _ in range(T):
        p = int(data[idx])
        k = int(data[idx + 1])
        a = int(data[idx + 2])
        idx += 3
        answers.append(str(_solve_query(p, k, a)))
    if answers:
        sys.stdout.write("\n".join(answers) + "\n")


def _solve_query(p, k, a):
    """Return some x with x**k == a (mod p) for prime p, or -1 if none exists."""
    # Reduce first so that a nonzero multiple of p is treated as 0.
    a %= p
    if a == 0:
        return 0  # 0**k == 0; assumes k >= 1 as the original shortcut did
    d = gcd(k, p - 1)
    m = (p - 1) // d
    # a is a k-th power residue iff a**((p-1)/gcd(k, p-1)) == 1.
    if pow(a, m, p) != 1:
        return -1
    if d == 1:
        # x -> x**k permutes the nonzero residues; invert k modulo p - 1.
        inv_k = modinv(k, p - 1)
        return -1 if inv_k is None else pow(a, inv_k, p)
    # First find y with y**d == a, then lift it to a k-th root.
    y = tonelli_shanks(a, p) if d == 2 else amm(a, d, p)
    if y is None:
        return -1
    # gcd(k/d, m) == 1, so (y**t)**k == a for t = (k/d)**-1 mod m.
    inv_kd = modinv(k // d, m)
    if inv_kd is None:
        return -1
    return pow(y, inv_kd, p)

if __name__ == '__main__':
    # Script entry point: answer all queries from stdin (skipped on import).
    solve()
0