結果
| 問題 |
No.981 一般冪乗根
|
| ユーザー |
lam6er
|
| 提出日時 | 2025-03-31 17:54:02 |
| 言語 | PyPy3 (7.3.15) |
| 結果 |
WA
|
| 実行時間 | - |
| コード長 | 4,011 bytes |
| コンパイル時間 | 147 ms |
| コンパイル使用メモリ | 82,360 KB |
| 実行使用メモリ | 75,020 KB |
| 最終ジャッジ日時 | 2025-03-31 17:56:56 |
| 合計ジャッジ時間 | 41,886 ms |
|
ジャッジサーバーID (参考情報) |
judge4 / judge1 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| other | AC * 3 WA * 32 TLE * 1 -- * 8 |
ソースコード
import sys
import math
import random
def input():
return sys.stdin.read()
def pow_mod(a, b, p):
return pow(a, b, p)
def gcd(a, b):
return math.gcd(a, b)
def modinv(a, m):
g, x, y = extended_gcd(a, m)
if g != 1:
return None
else:
return x % m
def extended_gcd(a, b):
if a == 0:
return (b, 0, 1)
else:
g, y, x = extended_gcd(b % a, a)
return (g, x - (b // a) * y, y)
def tonelli_shanks(n, p):
assert pow_mod(n, (p - 1) // 2, p) == 1, "n is not a square (mod p)"
if p % 4 == 3:
x = pow_mod(n, (p + 1) // 4, p)
return x
Q = p - 1
S = 0
while Q % 2 == 0:
Q //= 2
S += 1
z = 2
while pow_mod(z, (p - 1) // 2, p) != p - 1:
z += 1
c = pow_mod(z, Q, p)
x = pow_mod(n, (Q + 1) // 2, p)
t = pow_mod(n, Q, p)
m = S
while t != 1:
tmp = t
i = 0
while tmp != 1 and i < m:
tmp = pow_mod(tmp, 2, p)
i += 1
if i == m:
return None
b = pow_mod(c, 1 << (m - i - 1), p)
x = (x * b) % p
t = (t * b * b) % p
c = (b * b) % p
m = i
return x
def find_square_root(a, p):
a %= p
if a == 0:
return 0
if p == 2:
return a
if pow_mod(a, (p - 1) // 2, p) != 1:
return None
return tonelli_shanks(a, p)
def nth_root_mod(a, e, p):
if a == 0:
return 0
g = gcd(e, p - 1)
m = (p - 1) // g
if pow_mod(a, m, p) != 1:
return None
if g == 1:
inv_e = modinv(e, p - 1)
if inv_e is None:
return None
return pow_mod(a, inv_e, p)
if g == 2:
root = find_square_root(a, p)
return root
def find_prime_factors(n):
i = 2
factors = {}
while i * i <= n:
while n % i == 0:
factors[i] = factors.get(i, 0) + 1
n //= i
i += 1
if n > 1:
factors[n] = 1
return factors
factors = find_prime_factors(g)
current_a = a
current_e = 1
x = 1
for q in factors:
exp = factors[q]
q_exp = q ** exp
new_e = e // q_exp
new_g = gcd(new_e, p - 1)
if new_g != 1:
return None
inv_new_e = modinv(new_e, p - 1)
if inv_new_e is None:
return None
part_root = pow_mod(current_a, inv_new_e, p)
x = (x * part_root) % p
current_e *= q_exp
return x
def solve():
data = sys.stdin.read().split()
T = int(data[0])
idx = 1
for _ in range(T):
p = int(data[idx])
k = int(data[idx+1])
a = int(data[idx+2])
idx +=3
if a == 0:
print(0)
continue
if p == 2:
print(a)
continue
k_mod = k % (p-1)
if k_mod == 0:
if a % p == 1:
print(1)
else:
print(-1)
continue
g = gcd(k_mod, p-1)
m = (p-1) // g
if pow_mod(a, m, p) != 1:
print(-1)
continue
a_mod = a % p
k_prime = k_mod // g
inv_k_prime = modinv(k_prime, m)
if inv_k_prime is None:
print(-1)
continue
if g == 1:
res = pow_mod(a_mod, inv_k_prime, p)
print(res)
continue
candidate = None
if g == 2:
candidate = find_square_root(a_mod, p)
if candidate is None:
print(-1)
continue
res = pow_mod(candidate, inv_k_prime, p)
print(res % p)
continue
roots = []
try_root = nth_root_mod(a_mod, g, p)
if try_root is None:
print(-1)
continue
else:
res = pow_mod(try_root, inv_k_prime, p)
print(res)
solve()
lam6er