結果
問題 |
No.3231 2×2行列相似判定 〜hard〜
|
ユーザー |
|
提出日時 | 2025-08-08 23:29:07 |
言語 | PyPy3 (7.3.15) |
結果 |
WA
|
実行時間 | - |
コード長 | 7,497 bytes |
コンパイル時間 | 327 ms |
コンパイル使用メモリ | 82,568 KB |
実行使用メモリ | 78,752 KB |
最終ジャッジ日時 | 2025-08-08 23:29:14 |
合計ジャッジ時間 | 6,438 ms |
ジャッジサーバーID (参考情報) |
judge4 / judge1 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | AC * 3 |
other | AC * 31 WA * 6 |
ソースコード
import sys input = lambda :sys.stdin.readline()[:-1] ni = lambda :int(input()) na = lambda :list(map(int,input().split())) yes = lambda :print("yes");Yes = lambda :print("Yes");YES = lambda : print("YES") no = lambda :print("no");No = lambda :print("No");NO = lambda : print("NO") ####################################################################### def is_prime(n): 'O(logN) miller rabin algorithm' if n == 2: return 1 if n == 1 or not n&1: return 0 #miller_rabin if n < 1<<30: test_numbers = [2, 7, 61] else: test_numbers = [2, 325, 9375, 28178, 450775, 9780504, 1795265022] d = n - 1 while ~d&1: d>>=1 for a in test_numbers: if n <= a: break t = d y = pow(a, t, n) while t != n-1 and y != 1 and y != n-1: y = pow(y, 2, n) t <<= 1 if y != n-1 and not t&1: return 0 return 1 from functools import reduce def gcd2(a: int, b: int) -> int: while a: a, b = b % a, a return b def gcd(*numbers) -> int: return reduce(gcd2, numbers) def lcm2(x: int, y: int) -> int: return (x * y) // gcd2(x, y) def lcm(*integers) -> int: return reduce(lcm2, integers) def extgcd(a: int, b: int): 'Tuple[gcd(a, b), x, y] s.t. ax + by = gcd(a, b) (Extended Euclidean algorithm)' if b: d, y, x = extgcd(b, a % b) y -= (a // b) * x return d, x, y return a, 1, 0 def crt(V): 'V: [(X_i, Y_i), ...]: X_i (mod Y_i)' x = 0; d = 1 for X, Y in V: g, a, b = extgcd(d, Y) x, d = (Y*b*x + d*a*X) // g, d*(Y // g) x %= d return x, d from random import randrange def pollard_rho(n): b = n.bit_length()-1 b = (b>>2)<<2 m = int(2**(b/8))<<1 while True: c = randrange(1, n) f = lambda a: (pow(a,2,n)+c)%n y = 0 g = q = r = 1 while g == 1: x = y for _ in range(r): y = f(y) k = 0 while k < r and g == 1: ys = y for _ in range(min(m, r-k)): y = f(y) q = q*abs(x-y)%n g = gcd2(q, n) k += m r <<= 1 if g == n: g = 1 y = ys while g == 1: y = f(y) g = gcd2(abs(x-y), n) if g == n: continue if is_prime(g): return g elif is_prime(n//g): return n//g else: n = g def factorize(n): 'O(N**0.25) pollard rho algorithm' res = {} for p in range(2,1000): if p*p > n: break if n%p: continue s = 0 while n%p == 0: n //= p s += 1 res[p] = s while not is_prime(n) and n > 1: p = pollard_rho(n) s = 0 while n%p == 0: n //= p s += 1 res[p] = s if n > 1: res[n] = 1 return res from collections import defaultdict class _Memo: def __init__(self, g: int, s: int, period: int, mod: int): self.lg = min(s, period).bit_length() - 1 self.size = size = 1 << self.lg self.mask = mask = size - 1 self.period = period self.mod = mod self.vs = vs = [[0, 0] for _ in range(size)] self.os = os = [0] * (size + 1) x = 1 for i in range(size): os[x & mask] += 1 x = x * g % mod for i in range(1, size): os[i] += os[i - 1] x = 1 for i in range(size): tmp = os[x & mask] - 1 vs[tmp] = [x, i] os[x & mask] = tmp x = x * g % mod self.gpow = x os[size] = size def find(self, x: int) -> int: size = self.size; period = self.period; mod = self.mod; gpow = self.gpow; mask = self.mask os = self.os; vs = self.vs t = 0 while t < period: m = x & mask i = os[m] while i < os[m + 1]: if x == vs[i][0]: res = vs[i][1] - t return res + period if res < 0 else res i += 1 t += size x = x * gpow % mod def pe_root(c: int, pi: int, ei: int, p: int) -> int: s = p - 1; t = 0 while not s % pi: s //= pi t += 1 pe = pow(pi, ei) u = inv(pe - s % pe, pe) mc = c % p z = pow(mc, (s * u + 1) // pe, p) zpe = pow(mc, s * u, p) if zpe == 1: return z ptm1 = pow(pi, t - 1) v = 2 vs = pow(v, s, p) v = 3 while pow(vs, ptm1, p) == 1: vs = pow(v, s, p) v += 1 vspe = pow(vs, pe, p) vs_e = ei base = vspe for _ in range(t - ei - 1): base = pow(base, pi, p) memo = _Memo(base, int((t - ei) ** 0.5 * pi ** 0.5) + 1, pi, p) while zpe != 1: tmp = zpe td = 0 while tmp != 1: td += 1 tmp = pow(tmp, pi, p) e = t - td while vs_e != e: vs = pow(vs, pi, p) vspe = pow(vspe, pi, p) vs_e += 1 base_zpe = pow(zpe, p - 2, p) for _ in range(td - 1): base_zpe = pow(base_zpe, pi, p) bsgs = memo.find(base_zpe) z = z * pow(vs, bsgs, p) % p zpe = zpe * pow(vspe, bsgs, p) % p return z def _kth_root(a: int, k: int, p: int) -> int: a %= p if k == 0: return a if a == 1 else -1 if a <= 1 or k <= 1: return a g = gcd2(p - 1, k) if pow(a, (p - 1) // g, p) != 1: return -1 a = pow(a, inv(k // g, (p - 1) // g), p) fac = defaultdict(int) for prime, cnt in factorize(g).items(): fac[prime] += cnt for k, v in fac.items(): a = pe_root(a, k, v, p) return a def kth_root(a: int, k: int, p: int) -> int: """ X s.t. pow(X, k) == a (mod p) """ return _kth_root(a, k, p) def inv(a: int, p: int) -> int: b = p; x = 1; y = 0 while a: q = b // a a, b = b % a, a x, y = y - q * x, x return y + p if y < 0 else y def g(a, b): # x^2 + ax + b = 0 (mod p) # (x + a/2)^2 = (a/2)^2 - b A = pow(2, mod - 2, mod) * a % mod res = kth_root((A * A - b) % mod, 2, mod) if res == -1: return -1 return (res - A ) % mod mod = 10 ** 9 + 7 def rank(A): if A == [[0, 0], [0, 0]]: r = 0 elif (A[0][0] * A[1][1] - A[0][1] * A[1][0]) % mod == 0: r = 1 else: r = 2 return r def f(A): return g(-(A[0][0] + A[1][1]), (A[0][0] * A[1][1] - A[0][1] * A[1][0]) % mod) def gg(A): return (A[0][0] + A[1][1]) % mod, (A[0][0] * A[1][1] - A[0][1] * A[1][0]) % mod def mat_mul(A, B): return [[(A[0][0] * B[0][0] + A[0][1] * B[1][0]) % mod, (A[0][0] * B[0][1] + A[0][1] * B[1][1]) % mod], [(A[1][0] * B[0][0] + A[1][1] * B[1][0]) % mod, (A[1][0] * B[0][1] + A[1][1] * B[1][1]) % mod]] def inv_mat(A): INV = pow(A[0][0] * A[1][1] - A[0][1] * A[1][0], mod - 2, mod) return [[A[1][1] * INV % mod, -A[0][1] * INV % mod], [-A[1][0] * INV % mod, A[0][0] * INV % mod]] A = [na(), na()] B = [na(), na()] ans = 1 from random import randint SA = set() SB = set() for _ in range(1000): # print(f(A), f(B)) p = [[randint(0, mod -1) for _ in range(2)] for _ in range(2)] if rank(p) != 2: continue # print(A) A = mat_mul(inv_mat(p), mat_mul(A, p)) B = mat_mul(inv_mat(p), mat_mul(B, p)) SA.add(f(A)) SB.add(f(B)) if -1 in SA: SA.remove(-1) if -1 in SB: SB.remove(-1) # print(SA, SB) if rank(A) != rank(B): No() elif SA == SB: Yes() else: No()