結果
| 問題 | No.3231 2×2行列相似判定 〜hard〜 |
| コンテスト | |
| ユーザー |
|
| 提出日時 | 2025-08-08 23:15:00 |
| 言語 | PyPy3 (7.3.15) |
| 結果 |
WA
|
| 実行時間 | - |
| コード長 | 7,441 bytes |
| コンパイル時間 | 315 ms |
| コンパイル使用メモリ | 82,568 KB |
| 実行使用メモリ | 78,576 KB |
| 最終ジャッジ日時 | 2025-08-08 23:15:08 |
| 合計ジャッジ時間 | 6,755 ms |
|
ジャッジサーバーID (参考情報) |
judge4 / judge5 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 2 WA * 1 |
| other | AC * 19 WA * 18 |
ソースコード
import sys
input = lambda :sys.stdin.readline()[:-1]
ni = lambda :int(input())
na = lambda :list(map(int,input().split()))
yes = lambda :print("yes");Yes = lambda :print("Yes");YES = lambda : print("YES")
no = lambda :print("no");No = lambda :print("No");NO = lambda : print("NO")
#######################################################################
def is_prime(n):
'O(logN) miller rabin algorithm'
if n == 2: return 1
if n == 1 or not n&1: return 0
#miller_rabin
if n < 1<<30: test_numbers = [2, 7, 61]
else: test_numbers = [2, 325, 9375, 28178, 450775, 9780504, 1795265022]
d = n - 1
while ~d&1: d>>=1
for a in test_numbers:
if n <= a: break
t = d
y = pow(a, t, n)
while t != n-1 and y != 1 and y != n-1:
y = pow(y, 2, n)
t <<= 1
if y != n-1 and not t&1: return 0
return 1
from functools import reduce
def gcd2(a: int, b: int) -> int:
while a: a, b = b % a, a
return b
def gcd(*numbers) -> int: return reduce(gcd2, numbers)
def lcm2(x: int, y: int) -> int: return (x * y) // gcd2(x, y)
def lcm(*integers) -> int: return reduce(lcm2, integers)
def extgcd(a: int, b: int):
'Tuple[gcd(a, b), x, y] s.t. ax + by = gcd(a, b) (Extended Euclidean algorithm)'
if b:
d, y, x = extgcd(b, a % b)
y -= (a // b) * x
return d, x, y
return a, 1, 0
def crt(V):
'V: [(X_i, Y_i), ...]: X_i (mod Y_i)'
x = 0; d = 1
for X, Y in V:
g, a, b = extgcd(d, Y)
x, d = (Y*b*x + d*a*X) // g, d*(Y // g)
x %= d
return x, d
from random import randrange
def pollard_rho(n):
b = n.bit_length()-1
b = (b>>2)<<2
m = int(2**(b/8))<<1
while True:
c = randrange(1, n)
f = lambda a: (pow(a,2,n)+c)%n
y = 0
g = q = r = 1
while g == 1:
x = y
for _ in range(r): y = f(y)
k = 0
while k < r and g == 1:
ys = y
for _ in range(min(m, r-k)):
y = f(y)
q = q*abs(x-y)%n
g = gcd2(q, n)
k += m
r <<= 1
if g == n:
g = 1
y = ys
while g == 1:
y = f(y)
g = gcd2(abs(x-y), n)
if g == n: continue
if is_prime(g): return g
elif is_prime(n//g): return n//g
else: n = g
def factorize(n):
'O(N**0.25) pollard rho algorithm'
res = {}
for p in range(2,1000):
if p*p > n: break
if n%p: continue
s = 0
while n%p == 0:
n //= p
s += 1
res[p] = s
while not is_prime(n) and n > 1:
p = pollard_rho(n)
s = 0
while n%p == 0:
n //= p
s += 1
res[p] = s
if n > 1: res[n] = 1
return res
from collections import defaultdict
class _Memo:
def __init__(self, g: int, s: int, period: int, mod: int):
self.lg = min(s, period).bit_length() - 1
self.size = size = 1 << self.lg
self.mask = mask = size - 1
self.period = period
self.mod = mod
self.vs = vs = [[0, 0] for _ in range(size)]
self.os = os = [0] * (size + 1)
x = 1
for i in range(size):
os[x & mask] += 1
x = x * g % mod
for i in range(1, size): os[i] += os[i - 1]
x = 1
for i in range(size):
tmp = os[x & mask] - 1
vs[tmp] = [x, i]
os[x & mask] = tmp
x = x * g % mod
self.gpow = x
os[size] = size
def find(self, x: int) -> int:
size = self.size; period = self.period; mod = self.mod; gpow = self.gpow; mask = self.mask
os = self.os; vs = self.vs
t = 0
while t < period:
m = x & mask
i = os[m]
while i < os[m + 1]:
if x == vs[i][0]:
res = vs[i][1] - t
return res + period if res < 0 else res
i += 1
t += size
x = x * gpow % mod
def pe_root(c: int, pi: int, ei: int, p: int) -> int:
s = p - 1; t = 0
while not s % pi:
s //= pi
t += 1
pe = pow(pi, ei)
u = inv(pe - s % pe, pe)
mc = c % p
z = pow(mc, (s * u + 1) // pe, p)
zpe = pow(mc, s * u, p)
if zpe == 1: return z
ptm1 = pow(pi, t - 1)
v = 2
vs = pow(v, s, p)
v = 3
while pow(vs, ptm1, p) == 1:
vs = pow(v, s, p)
v += 1
vspe = pow(vs, pe, p)
vs_e = ei
base = vspe
for _ in range(t - ei - 1): base = pow(base, pi, p)
memo = _Memo(base, int((t - ei) ** 0.5 * pi ** 0.5) + 1, pi, p)
while zpe != 1:
tmp = zpe
td = 0
while tmp != 1:
td += 1
tmp = pow(tmp, pi, p)
e = t - td
while vs_e != e:
vs = pow(vs, pi, p)
vspe = pow(vspe, pi, p)
vs_e += 1
base_zpe = pow(zpe, p - 2, p)
for _ in range(td - 1): base_zpe = pow(base_zpe, pi, p)
bsgs = memo.find(base_zpe)
z = z * pow(vs, bsgs, p) % p
zpe = zpe * pow(vspe, bsgs, p) % p
return z
def _kth_root(a: int, k: int, p: int) -> int:
a %= p
if k == 0: return a if a == 1 else -1
if a <= 1 or k <= 1: return a
g = gcd2(p - 1, k)
if pow(a, (p - 1) // g, p) != 1: return -1
a = pow(a, inv(k // g, (p - 1) // g), p)
fac = defaultdict(int)
for prime, cnt in factorize(g).items(): fac[prime] += cnt
for k, v in fac.items(): a = pe_root(a, k, v, p)
return a
def kth_root(a: int, k: int, p: int) -> int:
"""
X s.t. pow(X, k) == a (mod p)
"""
return _kth_root(a, k, p)
def inv(a: int, p: int) -> int:
b = p; x = 1; y = 0
while a:
q = b // a
a, b = b % a, a
x, y = y - q * x, x
return y + p if y < 0 else y
def g(a, b): # x^2 + ax + b = 0 (mod p)
# (x + a/2)^2 = (a/2)^2 - b
res = kth_root((a * a * pow(4, mod-2, mod) - b) % mod, 2, mod)
return res
mod = 10 ** 9 + 7
def rank(A):
if A == [[0, 0], [0, 0]]:
r = 0
elif (A[0][0] * A[1][1] - A[0][1] * A[1][0]) % mod == 0:
r = 1
else:
r = 2
return r
def f(A):
return g(-(A[0][0] + A[1][1]), (A[0][1] + A[1][0]) % mod)
def g(A):
return (A[0][0] + A[1][1]) % mod, (A[0][0] * A[1][1] - A[0][1] * A[1][0]) % mod
def mat_mul(A, B):
return [[(A[0][0] * B[0][0] + A[0][1] * B[1][0]) % mod, (A[0][0] * B[0][1] + A[0][1] * B[1][1]) % mod],
[(A[1][0] * B[0][0] + A[1][1] * B[1][0]) % mod, (A[1][0] * B[0][1] + A[1][1] * B[1][1]) % mod]]
def inv_mat(A):
INV = pow(A[0][0] * A[1][1] - A[0][1] * A[1][0], mod - 2, mod)
return [[A[1][1] * INV % mod, -A[0][1] * INV % mod], [-A[1][0] * INV % mod, A[0][0] * INV % mod]]
A = [na(), na()]
B = [na(), na()]
ans = 1
from random import randint
SA = {}
SB = {}
for _ in range(100000):
# print(f(A), f(B))
p = [[randint(0, mod -1) for _ in range(2)] for _ in range(2)]
if rank(p) == 2:
continue
A = mat_mul(inv_mat(p), mat_mul(A, p))
B = mat_mul(inv_mat(p), mat_mul(B, p))
SA.add(f(A))
SB.add(f(B))
# print(SA, SB)
if rank(A) != rank(B):
No()
elif SA == SB == {}:
if A == [[1, 0], [0, 2]] and [[1, 0], [0, 8]] == B:
No()
else:
No()
elif SA == SB:
Yes()
else:
No()