結果
| 問題 | No.1112 冥界の音楽 |
|---|---|
| コンテスト | |
| ユーザー | |
| 提出日時 | 2023-06-14 02:42:10 |
| 言語 | PyPy3 (7.3.15) |
| 結果 | WA |
| 実行時間 | - |
| コード長 | 19,486 bytes |
| コンパイル時間 | 663 ms |
| コンパイル使用メモリ | 82,696 KB |
| 実行使用メモリ | 80,184 KB |
| 最終ジャッジ日時 | 2024-06-22 14:37:53 |
| 合計ジャッジ時間 | 4,873 ms |
| ジャッジサーバーID (参考情報) | judge2 / judge4 |

(要ログイン)

| ファイルパターン | 結果 |
|---|---|
| sample | AC * 3 |
| other | AC * 12 WA * 22 |
ソースコード
from copy import deepcopy
from random import randint
from __pypy__.builders import StringBuilder
import sys
from os import read as os_read, write as os_write
from atexit import register as atexist_register
from typing import Generic, Iterator, List, Tuple, Dict, Iterable, Sequence, Callable, Union, Optional, TypeVar
# Generic type variable plus typing aliases used only in annotations below.
T = TypeVar('T')

# Flat integer sequences.
Poly = List[int]      # coefficient list: index i holds the x**i coefficient
Vector = List[int]

# Nested integer lists.
Matrix = List[List[int]]
Graph = List[List[int]]

# Callable aliases named FuncAB: A int arguments, B results (0 -> None, 1 -> int).
Func10 = Callable[[int], None]
Func20 = Callable[[int, int], None]
Func11 = Callable[[int], int]
Func21 = Callable[[int, int], int]
Func31 = Callable[[int, int, int], int]
class Fastio:
    '''Buffered reader on fd 0 and buffered writer on fd 1.

    ``ibuf`` holds bytes read from fd 0, ``pil`` is the read cursor into it and
    ``pir`` the number of valid bytes. Output is collected in a PyPy
    ``StringBuilder`` and only reaches fd 1 through ``flush``/``flush_atexit``.
    '''
    ibuf = bytes()
    pil = pir = 0
    sb = StringBuilder()

    def load(self):
        '''Discard consumed bytes and append up to 128 KiB more from fd 0.'''
        self.ibuf = self.ibuf[self.pil:]
        self.ibuf += os_read(0, 131072)
        self.pil = 0
        self.pir = len(self.ibuf)

    def flush_atexit(self):
        '''Write whatever is currently buffered to fd 1.'''
        os_write(1, self.sb.build().encode())

    def flush(self):
        '''Write buffered output to fd 1 and start a fresh buffer.'''
        os_write(1, self.sb.build().encode())
        self.sb = StringBuilder()

    def fastin(self):
        '''Read the next (optionally '-'-prefixed) decimal integer.

        Bytes below '-' (whitespace/control characters) are skipped first.
        The number may end exactly at end of input.
        '''
        if self.pir - self.pil < 64:
            self.load()
        minus = x = 0
        while self.ibuf[self.pil] < 45:
            self.pil += 1
        if self.ibuf[self.pil] == 45:
            minus = 1
            self.pil += 1
        # The bounds check lets the final number end at EOF with no trailing newline.
        while self.pil < self.pir and self.ibuf[self.pil] >= 48:
            x = x * 10 + (self.ibuf[self.pil] & 15)
            self.pil += 1
        return -x if minus else x

    def fastin_string(self):
        '''Read the next whitespace-delimited token and return it as a bytearray.'''
        if self.pir - self.pil < 64:
            self.load()
        while self.ibuf[self.pil] <= 32:
            self.pil += 1
        res = bytearray()
        while True:
            if self.pir - self.pil < 64:
                self.load()
            # Stop at whitespace, or at end of input for a token with no trailing newline.
            if self.pil >= self.pir or self.ibuf[self.pil] <= 32:
                break
            res.append(self.ibuf[self.pil])
            self.pil += 1
        return res

    def fastout(self, x):
        '''Buffer str(x).'''
        self.sb.append(str(x))

    def fastoutln(self, x):
        '''Buffer str(x) followed by a newline.'''
        self.sb.append(str(x))
        self.sb.append('\n')
# One shared fast-I/O object and short aliases for its methods.
fastio = Fastio()
rd = fastio.fastin
rds = fastio.fastin_string
wt = fastio.fastout
wtn = fastio.fastoutln
flush = fastio.flush
# Buffered output is written when the interpreter exits.
atexist_register(fastio.flush_atexit)
# Detach the standard text streams (presumably so all I/O goes through fastio).
sys.stdin = None
sys.stdout = None
def rdl(n):
    '''Read n integers with rd() and return them as a list.'''
    return [rd() for _ in range(n)]


def wtnl(l):
    '''Write the items of l separated by single spaces, then a newline.'''
    wtn(' '.join(str(item) for item in l))


def wtn_yes():
    '''Write the line "Yes".'''
    wtn("Yes")


def wtn_no():
    '''Write the line "No".'''
    wtn("No")
def modinv(a: int, m: int) -> int:
    '''Return x with a * x == 1 (mod m), via the extended Euclidean algorithm.

    The result is only an inverse when gcd(a, m) == 1; no check is made.
    '''
    r_prev, r_cur = a, m
    s_prev, s_cur = 1, 0
    while r_cur:
        q = r_prev // r_cur
        r_prev, r_cur = r_cur, r_prev - q * r_cur
        s_prev, s_cur = s_cur, s_prev - q * s_cur
    return s_prev % m
# https://nyaannyaan.github.io/library/fps/berlekamp-massey.hpp
def berlekamp_massey(s: Vector, mod: int) -> Vector:
    '''Shortest linear recurrence of the sequence s over Z/mod.

    Returns C = [1, c1, ..., cL] such that sum(C[j] * s[i - j] for j in 0..L)
    is 0 (mod mod) for every L <= i < len(s). mod must make the discrepancies
    invertible, e.g. a prime.
    '''
    cur = [1]        # current connection polynomial, coefficients in reverse order
    ref = [1]        # polynomial saved at the last length change, shifted each step
    ref_disc = 1     # discrepancy recorded at that length change
    for t in range(1, len(s) + 1):
        cl = len(cur)
        disc = sum(coef * s[t - cl + i] for i, coef in enumerate(cur)) % mod
        ref.append(0)
        rl = len(ref)
        if disc == 0:
            continue
        ratio = disc * modinv(ref_disc, mod) % mod
        if cl < rl:
            # The recurrence must grow; the old one becomes the new reference.
            saved = cur[:]
            cur = [0] * (rl - cl) + cur
            for k in range(rl):
                cur[k] = (cur[k] - ratio * ref[k]) % mod
            ref = saved
            ref_disc = disc
        else:
            shift = cl - rl
            for k in range(rl):
                cur[shift + k] = (cur[shift + k] - ratio * ref[k]) % mod
    cur.reverse()
    return cur
# Prime modulus used by NTT, FPS and the matrix helpers below; NTT.ntt_doubling
# uses 3 as its primitive root.
MOD = 998244353
# Root-of-unity constants consumed by NTT._fft / NTT._ifft.
# NOTE(review): these appear to be precomputed for MOD = 998244353 (they match the
# AtCoder Library butterfly tables); regenerate them if MOD is ever changed.
# _IMAG multiplies (a1 - a3) in the forward radix-4 step; _IIMAG multiplies
# (a2 - a3) in the inverse radix-4 step.
# In the _rate* / _irate* tuples index 0 is a placeholder: lookups use
# (~s & -~s).bit_length(), which is always >= 1.
_IMAG = 911660635
_IIMAG = 86583718
_rate2 = (0, 911660635, 509520358, 369330050, 332049552, 983190778, 123842337, 238493703, 975955924, 603855026, 856644456, 131300601, 842657263, 730768835, 942482514, 806263778, 151565301, 510815449, 503497456, 743006876, 741047443, 56250497, 867605899, 0)
_rate3 = (0, 372528824, 337190230, 454590761, 816400692, 578227951, 180142363, 83780245, 6597683, 70046822, 623238099, 183021267, 402682409, 631680428, 344509872, 689220186, 365017329, 774342554, 729444058, 102986190, 128751033, 395565204, 0)
_irate3 = (0, 509520358, 929031873, 170256584, 839780419, 282974284, 395914482, 444904435, 72135471, 638914820, 66769500, 771127074, 985925487, 262319669, 262341272, 625870173, 768022760, 859816005, 914661783, 430819711, 272774365, 530924681, 0)
class NTT:
    '''Number-theoretic transform and convolution modulo MOD.

    _fft/_ifft work in place on a list whose length is a power of two, using
    radix-4 passes plus at most one radix-2 pass and the module-level tables
    _IMAG, _IIMAG, _rate2, _rate3 and _irate3. _ifft undoes _fft only up to a
    factor of len(a); intt/multiply/pow2 divide that factor out.
    '''

    @staticmethod
    def _fft(a: Vector) -> None:
        '''Forward butterfly, in place.'''
        n = len(a)
        h = (n - 1).bit_length()
        # le keeps its last loop value afterwards (0 if no radix-4 pass runs).
        le = 0
        for le in range(0, h - 1, 2):
            p = 1 << (h - le - 2)
            rot = 1
            for s in range(1 << le):
                rot2 = rot * rot % MOD
                rot3 = rot2 * rot % MOD
                offset = s << (h - le)
                for i in range(p):
                    a0 = a[i + offset]
                    a1 = a[i + offset + p] * rot
                    a2 = a[i + offset + p * 2] * rot2
                    a3 = a[i + offset + p * 3] * rot3
                    a1na3imag = (a1 - a3) % MOD * _IMAG
                    a[i + offset] = (a0 + a2 + a1 + a3) % MOD
                    a[i + offset + p] = (a0 + a2 - a1 - a3) % MOD
                    a[i + offset + p * 2] = (a0 - a2 + a1na3imag) % MOD
                    a[i + offset + p * 3] = (a0 - a2 - a1na3imag) % MOD
                # (~s & -~s) isolates the lowest zero bit of s.
                rot = rot * _rate3[(~s & -~s).bit_length()] % MOD
        # An odd number of levels leaves one final radix-2 pass.
        if (h - le) & 1:
            rot = 1
            for s in range(1 << (h - 1)):
                offset = s << 1
                l = a[offset]
                r = a[offset + 1] * rot
                a[offset] = (l + r) % MOD
                a[offset + 1] = (l - r) % MOD
                rot = rot * _rate2[(~s & -~s).bit_length()] % MOD

    @staticmethod
    def _ifft(a: Vector) -> None:
        '''Inverse butterfly, in place, without the 1/len(a) scaling.'''
        n = len(a)
        h = (n - 1).bit_length()
        le = h
        for le in range(h, 1, -2):
            p = 1 << (h - le)
            irot = 1
            for s in range(1 << (le - 2)):
                irot2 = irot * irot % MOD
                irot3 = irot2 * irot % MOD
                offset = s << (h - le + 2)
                for i in range(p):
                    a0 = a[i + offset]
                    a1 = a[i + offset + p]
                    a2 = a[i + offset + p * 2]
                    a3 = a[i + offset + p * 3]
                    a2na3iimag = (a2 - a3) * _IIMAG % MOD
                    a[i + offset] = (a0 + a1 + a2 + a3) % MOD
                    a[i + offset + p] = (a0 - a1 + a2na3iimag) * irot % MOD
                    a[i + offset + p * 2] = (a0 + a1 - a2 - a3) * irot2 % MOD
                    a[i + offset + p * 3] = (a0 - a1 - a2na3iimag) * irot3 % MOD
                irot = irot * _irate3[(~s & -~s).bit_length()] % MOD
        # One remaining level (odd h) is a twiddle-free radix-2 pass.
        if le & 1:
            p = 1 << (h - 1)
            for i in range(p):
                l = a[i]
                r = a[i + p]
                a[i] = l + r if l + r < MOD else l + r - MOD
                a[i + p] = l - r if l - r >= 0 else l - r + MOD

    @classmethod
    def ntt(cls, a: Vector) -> None:
        '''Forward transform in place (no-op for length <= 1).'''
        if len(a) <= 1:
            return
        cls._fft(a)

    @classmethod
    def intt(cls, a: Vector) -> None:
        '''Inverse transform in place, including division by len(a).'''
        if len(a) <= 1:
            return
        cls._ifft(a)
        iv = modinv(len(a), MOD)
        for i, x in enumerate(a):
            a[i] = x * iv % MOD

    @classmethod
    def multiply(cls, s: Vector, t: Vector) -> Vector:
        '''Return the convolution of s and t modulo MOD.

        The result has len(s) + len(t) - 1 entries, or is [] when either input
        is empty. Operands with at most 60 terms use schoolbook multiplication.
        '''
        if not s or not t:
            # Otherwise one empty operand would produce len(other) - 1 zeros.
            return []
        n, m = len(s), len(t)
        l = n + m - 1
        if min(n, m) <= 60:
            a = [0] * l
            for i, x in enumerate(s):
                for j, y in enumerate(t):
                    a[i + j] += x * y
            return [x % MOD for x in a]
        z = 1 << (l - 1).bit_length()
        a = s + [0] * (z - n)
        b = t + [0] * (z - m)
        cls._fft(a)
        cls._fft(b)
        for i, x in enumerate(b):
            a[i] = a[i] * x % MOD
        cls._ifft(a)
        a[l:] = []
        iz = modinv(z, MOD)
        return [x * iz % MOD for x in a]

    @classmethod
    def pow2(cls, s: Vector) -> Vector:
        '''Return the convolution of s with itself modulo MOD ([] for empty s).'''
        n = len(s)
        l = (n << 1) - 1
        if n <= 60:
            a = [0] * l
            for i, x in enumerate(s):
                for j, y in enumerate(s):
                    a[i + j] += x * y
            return [x % MOD for x in a]
        z = 1 << (l - 1).bit_length()
        a = s + [0] * (z - n)
        cls._fft(a)
        for i, x in enumerate(a):
            a[i] = x * x % MOD
        cls._ifft(a)
        a[l:] = []
        iz = modinv(z, MOD)
        return [x * iz % MOD for x in a]

    @classmethod
    def ntt_doubling(cls, a: Vector) -> None:
        '''Extend the length-M transform a in place to the transform of length 2M.

        The extra M values are obtained by inverse-transforming a copy, twisting
        it by powers of a primitive 2M-th root of unity (from generator 3), and
        transforming again.
        '''
        M = len(a)
        b = a[:]
        cls.intt(b)
        r = 1
        zeta = pow(3, (MOD - 1) // (M << 1), MOD)
        for i, x in enumerate(b):
            b[i] = x * r % MOD
            r = r * zeta % MOD
        cls.ntt(b)
        a += b
# https://nyaannyaan.github.io/library/fps/formal-power-series.hpp
# https://nyaannyaan.github.io/library/fps/ntt-friendly-fps.hpp
class FPS:
    '''Polynomial / formal-power-series helpers over Z/MOD.

    A polynomial is a plain list of coefficients, index i holding the
    coefficient of x**i. shrink and resize modify their argument in place;
    the other methods return new values.
    '''

    @staticmethod
    def shrink(a: Poly) -> None:
        '''Remove trailing (highest-degree) zero coefficients of a, in place.'''
        while a and not a[-1]:
            a.pop()

    @staticmethod
    def resize(a: Poly, length: int, val: int=0) -> None:
        '''Truncate a to length, or pad it with val up to length, in place.'''
        a[length:] = []
        a[len(a):] = [val] * (length - len(a))

    @staticmethod
    def add(l: Poly, r: Union[Poly, int]) -> Poly:
        '''Return l + r; an int r is added to the constant term.

        Raises TypeError if r is neither an int nor a list.
        '''
        if type(r) is int:
            # The empty list is the zero polynomial, so give it a constant term.
            res = l[:] if l else [0]
            res[0] = (res[0] + r) % MOD
            return res
        if type(r) is list:
            if len(l) < len(r):
                res = r[::]
                for i, x in enumerate(l):
                    res[i] += x
            else:
                res = l[::]
                for i, x in enumerate(r):
                    res[i] += x
            return [x % MOD for x in res]
        raise TypeError()

    @classmethod
    def sub(cls, l: Poly, r: Union[Poly, int]) -> Poly:
        '''Return l - r (r may be an int or a list; TypeError otherwise).'''
        if type(r) is int:
            return cls.add(l, -r)
        if type(r) is list:
            return cls.add(l, cls.neg(r))
        raise TypeError()

    @staticmethod
    def neg(a: Poly) -> Poly:
        '''Return -a, assuming coefficients already lie in [0, MOD).'''
        return [MOD - x if x else 0 for x in a]

    @staticmethod
    def mul(l: Poly, r: Union[Poly, int]) -> Poly:
        '''Return l * r: scaling for an int r, convolution for a list r.

        Raises TypeError if r is neither an int nor a list.
        '''
        if type(r) is int:
            return [x * r % MOD for x in l]
        if type(r) is list:
            if not l or not r:
                return []
            return NTT.multiply(l, r)
        raise TypeError()

    @staticmethod
    def matmul(l: Poly, r: Poly) -> Poly:
        '''Return the coefficient-wise product of l and r (not verified).

        r must have at least len(l) entries.
        '''
        return [x * r[i] % MOD for i, x in enumerate(l)]

    @classmethod
    def div(cls, l: Poly, r: Poly) -> Poly:
        '''Return quo with l == r * quo + rem (mod MOD); [] if len(l) < len(r).

        Divisors longer than 64 terms go through the reversed-series inverse and
        need r[-1] != 0. Shorter ones use schoolbook division after dropping r's
        trailing zeros; the quotient is then padded with that many zeros.
        '''
        if len(l) < len(r):
            return []
        n = len(l) - len(r) + 1
        if len(r) > 64:
            return NTT.multiply(l[::-1][:n], cls.inv(r[::-1], n))[:n][::-1]
        f, g = l[::], r[::]
        cnt = 0
        while g and not g[-1]:
            g.pop()
            cnt += 1
        # Make g monic; the quotient is rescaled by the same factor at the end.
        coef = modinv(g[-1], MOD)
        g = cls.mul(g, coef)
        deg = len(f) - len(g) + 1
        gs = len(g)
        quo = [0] * deg
        for i in range(deg)[::-1]:
            quo[i] = x = f[i + gs - 1] % MOD
            for j, y in enumerate(g):
                f[i + j] -= x * y
        return cls.mul(quo, coef) + [0] * cnt

    @classmethod
    def modulo(cls, l: Poly, r: Poly) -> Poly:
        '''Return rem with l == r * quo + rem, trailing zeros removed.'''
        res = cls.sub(l, NTT.multiply(cls.div(l, r), r))
        cls.shrink(res)
        return res

    @classmethod
    def divmod(cls, l: Poly, r: Poly) -> Tuple[Poly, Poly]:
        '''Return (quo, rem) with l == r * quo + rem; rem has trailing zeros removed.'''
        quo = cls.div(l, r)
        rem = cls.sub(l, NTT.multiply(quo, r))
        cls.shrink(rem)
        return quo, rem

    @staticmethod
    def eval(a: Poly, x: int) -> int:
        '''Return the value of a at x, modulo MOD.'''
        r = 0
        w = 1
        for v in a:
            r += w * v % MOD
            w = w * x % MOD
        return r % MOD

    @staticmethod
    def inv(a: Poly, deg: int=-1) -> Poly:
        '''Return g with a * g == 1 (mod x**deg); deg defaults to len(a).

        a[0] must be invertible modulo MOD. Returns [] when deg <= 0.
        '''
        if deg == -1:
            deg = len(a)
        if deg <= 0:
            return []
        res = [0] * deg
        res[0] = modinv(a[0], MOD)
        d = 1
        # Each round doubles the number of correct coefficients in res.
        while d < deg:
            f = [0] * (d << 1)
            tmp = min(len(a), d << 1)
            f[:tmp] = a[:tmp]
            g = [0] * (d << 1)
            g[:d] = res[:d]
            NTT.ntt(f)
            NTT.ntt(g)
            for i, x in enumerate(g):
                f[i] = f[i] * x % MOD
            NTT.intt(f)
            f[:d] = [0] * d
            NTT.ntt(f)
            for i, x in enumerate(g):
                f[i] = f[i] * x % MOD
            NTT.intt(f)
            for j in range(d, min(d << 1, deg)):
                res[j] = MOD - f[j] if f[j] else 0
            d <<= 1
        return res

    @classmethod
    def pow(cls, f: Poly, k: int, deg=-1) -> Poly:
        '''Return g with g == f**k (mod x**deg); deg defaults to len(f).

        The lowest non-zero term x**i of f is factored out so log/exp act on a
        series with constant term 1; the result is shifted back by i * k.
        '''
        n = len(f)
        if deg == -1:
            deg = n
        if k == 0:
            if not deg:
                return []
            ret = [0] * deg
            ret[0] = 1
            return ret
        for i, x in enumerate(f):
            if x:
                rev = modinv(x, MOD)
                ret = cls.mul(cls.exp(cls.mul(cls.log(cls.mul(f, rev)[i:], deg), k), deg), pow(x, k, MOD))
                ret[:0] = [0] * (i * k)
                if len(ret) < deg:
                    cls.resize(ret, deg)
                    return ret
                return ret[:deg]
            # The lowest term is at least x**((i + 1) * k); past deg the result is 0.
            if (i + 1) * k >= deg:
                break
        return [0] * deg

    @staticmethod
    def exp(f: Poly, deg: int=-1) -> Poly:
        '''Return g with log(g) == f (mod x**deg); expects f empty or f[0] == 0.

        Iterative doubling: each round extends b from m to 2m coefficients.
        '''
        if deg == -1:
            deg = len(f)
        # inv[i] is the modular inverse of i, grown on demand by integral().
        inv = [0, 1]

        def integral(f: Poly) -> Poly:
            n = len(f)
            while len(inv) <= n:
                j, k = divmod(MOD, len(inv))
                inv.append((-inv[k] * j) % MOD)
            return [0] + [x * inv[i + 1] % MOD for i, x in enumerate(f)]

        def diff(f: Poly) -> Poly:
            return [x * i % MOD for i, x in enumerate(f) if i]

        b: Poly = [1, (f[1] if 1 < len(f) else 0)]
        c: Poly = [1]
        z1: Poly = []
        z2: Poly = [1, 1]
        m = 2
        while m < deg:
            y = b + [0] * m
            NTT.ntt(y)
            z1 = z2
            z = [y[i] * p % MOD for i, p in enumerate(z1)]
            NTT.intt(z)
            z[:m >> 1] = [0] * (m >> 1)
            NTT.ntt(z)
            for i, p in enumerate(z1):
                z[i] = z[i] * (-p) % MOD
            NTT.intt(z)
            c[m >> 1:] = z[m >> 1:]
            z2 = c + [0] * m
            NTT.ntt(z2)
            tmp = min(len(f), m)
            x = f[:tmp] + [0] * (m - tmp)
            x = diff(x)
            x.append(0)
            NTT.ntt(x)
            for i, p in enumerate(x):
                x[i] = y[i] * p % MOD
            NTT.intt(x)
            for i, p in enumerate(b):
                if not i:
                    continue
                x[i - 1] -= p * i % MOD
            x += [0] * m
            for i in range(m - 1):
                x[m + i], x[i] = x[i], 0
            NTT.ntt(x)
            for i, p in enumerate(z2):
                x[i] = x[i] * p % MOD
            NTT.intt(x)
            x.pop()
            x = integral(x)
            x[:m] = [0] * m
            for i in range(m, min(len(f), m << 1)):
                x[i] += f[i]
            NTT.ntt(x)
            for i, p in enumerate(y):
                x[i] = x[i] * p % MOD
            NTT.intt(x)
            b[m:] = x[m:]
            m <<= 1
        return b[:deg]

    @classmethod
    def log(cls, f: Poly, deg=-1) -> Poly:
        '''Return g with g == log(f) (mod x**deg); expects f[0] == 1.

        Returns [] when deg <= 0.
        '''
        if deg == -1:
            deg = len(f)
        if deg <= 0:
            return []
        return cls.integral(cls.mul(cls.diff(f), cls.inv(f, deg))[:deg - 1])

    @staticmethod
    def integral(f: Poly) -> Poly:
        '''Return the antiderivative of f with zero constant term.'''
        n = len(f)
        res = [0] * (n + 1)
        # Fill res[i] with 1/i (mod MOD) by the usual recurrence, then scale by f.
        if n:
            res[1] = 1
        for i in range(2, n + 1):
            j, k = divmod(MOD, i)
            res[i] = (-res[k] * j) % MOD
        for i, x in enumerate(f):
            res[i + 1] = res[i + 1] * x % MOD
        return res

    @staticmethod
    def diff(f: Poly) -> Poly:
        '''Return the derivative df/dx.'''
        return [i * x % MOD for i, x in enumerate(f) if i]
# https://nyaannyaan.github.io/library/fps/mod-pow.hpp
def mod_pow(k: int, base: Poly, d: Poly) -> Poly:
    '''Return base**k modulo the polynomial d, with trailing zeros removed.

    d must be non-empty and its leading coefficient d[-1] invertible modulo
    MOD (the reversed d is inverted as a power series). Returns [1] for k == 0.

    Raises:
        ValueError: if k is negative (the squaring loop would never terminate,
            since -1 >> 1 == -1).
    '''
    assert(d)
    if k < 0:
        raise ValueError('mod_pow: exponent must be non-negative')
    ld = len(d)
    # Reducing base itself needs len(base) - ld + 1 inverse terms; afterwards
    # every operand has fewer than ld terms, so ld terms suffice.
    inv = FPS.inv(d[::-1], max(ld, len(base) - ld + 1))

    def quo(poly: Poly) -> Poly:
        '''Quotient of poly by d via the reversed-series inverse.'''
        if len(poly) < ld:
            return []
        n = len(poly) - ld + 1
        # Reverse first, then take n: a single slice with a computed negative
        # stop wraps around (and yields []) when n == len(poly).
        return NTT.multiply(poly[::-1][:n], inv[:n])[:n][::-1]

    def reduce_mod(poly: Poly) -> Poly:
        '''poly mod d, trailing zeros removed.'''
        rem = FPS.sub(poly, NTT.multiply(quo(poly), d))
        FPS.shrink(rem)
        return rem

    res = [1]
    b = reduce_mod(base)
    while k:
        if k & 1:
            res = reduce_mod(NTT.multiply(res, b))
        b = reduce_mod(NTT.pow2(b))
        k >>= 1
    return res
# https://nyaannyaan.github.io/library/matrix/black-box-linear-algebra.hpp
def inner_product(a: Poly, b: Poly) -> int:
    '''Return sum(a[i] * b[i]) modulo MOD; a and b must have equal length.'''
    assert(len(a) == len(b))
    total = sum(x * b[i] % MOD for i, x in enumerate(a))
    return total % MOD
def random_poly(n: int) -> Poly:
    '''Return a list of n independent uniform random residues in [0, MOD).'''
    top = MOD - 1
    return [randint(0, top) for _ in range(n)]
class ModMatrix:
    '''Dense n x n matrix over Z/MOD, stored as a list of row lists.'''

    def __init__(self, n: int) -> None:
        self.mat = [[0 for _ in range(n)] for _ in range(n)]

    def add(self, i: int, j: int, x: int) -> None:
        '''Add x to entry (i, j); reduction modulo MOD happens when it is used.'''
        row = self.mat[i]
        row[j] += x

    def __mul__(self, r: Poly) -> Poly:
        '''Return the matrix-vector product self * r modulo MOD.'''
        assert(len(self.mat) == len(r))
        out = []
        for row in self.mat:
            acc = sum(v * r[j] % MOD for j, v in enumerate(row))
            out.append(acc % MOD)
        return out

    def apply(self, i: int, r: int) -> None:
        '''Multiply row i by r in place, reducing entries modulo MOD.'''
        row = self.mat[i]
        row[:] = [v * r % MOD for v in row]
class SparseMatrix:
    '''n x n matrix over Z/MOD kept as per-row lists of packed entries.

    Each entry packs column j and value x as ``j << 30 | x``, which needs
    0 <= x < 2**30; values are therefore reduced modulo MOD on insertion
    (this relies on MOD < 2**30). Repeated (i, j) entries act additively.
    '''

    def __init__(self, n: int) -> None:
        self.mat: List[List[int]] = [[] for _ in range(n)]

    def add(self, i: int, j: int, x: int) -> None:
        '''Add x to entry (i, j).'''
        # Reduce first: a negative or >= 2**30 value would corrupt the column bits.
        self.mat[i].append(j << 30 | x % MOD)

    def __mul__(self, r: Poly) -> Poly:
        '''Return the matrix-vector product self * r modulo MOD.'''
        assert(len(self.mat) == len(r))
        return [sum((jx & 0x3fffffff) * r[jx >> 30] % MOD for jx in mati) % MOD for mati in self.mat]

    def apply(self, i: int, r: int) -> None:
        '''Multiply row i by r in place, reducing values modulo MOD.'''
        row = self.mat[i]
        for idx, jx in enumerate(row):
            row[idx] = (jx >> 30) << 30 | ((jx & 0x3fffffff) * r % MOD)
def vector_minpoly(b: List[Poly]) -> Poly:
    '''Project every vector of b onto one random vector and return the
    Berlekamp-Massey connection polynomial of the resulting scalar sequence.'''
    assert(b)
    proj = random_poly(len(b[0]))
    seq = [inner_product(vec, proj) for vec in b]
    return berlekamp_massey(seq, MOD)
def mat_minpoly(A: Union[ModMatrix, SparseMatrix]) -> Poly:
    '''Return the connection polynomial of the Krylov sequence
    u, A u, ..., A**(2n) u for a random start vector u (n = size of A).'''
    size = len(A.mat)
    vec = random_poly(size)
    krylov: List[Poly] = []
    for _ in range(2 * size + 1):
        krylov.append(vec)
        vec = A * vec
    return vector_minpoly(krylov)
def fast_pow(A: Union[ModMatrix, SparseMatrix], b: Poly, k: int) -> Poly:
    '''Return A**k * b (mod MOD) without forming A**k.

    c = x**k mod mp(x), where mp is the reversed connection polynomial from
    mat_minpoly(A); the result is sum(c[i] * A**i b), accumulated while walking
    the vectors b, A b, A**2 b, ....
    '''
    acc = [0] * len(b)
    mp = mat_minpoly(A)
    coeffs = mod_pow(k, [0, 1], mp[::-1])
    vec = b
    for c in coeffs:
        acc = FPS.add(acc, FPS.mul(vec, c))
        vec = A * vec
    return acc
def fast_det(A: Union[ModMatrix, SparseMatrix]) -> int:
    '''Return det(A) modulo MOD using randomized row scaling and Berlekamp-Massey.

    Each attempt scales row i of a copy of A by a random non-zero D[i] and takes
    the connection polynomial mp of the scaled matrix DA:
      * constant term mp[-1] == 0  -> A is singular, return 0;
      * full degree (len(mp) == n + 1) -> det(DA) = (-1)**n * mp[-1], and
        det(D) = prod(D) is divided out;
      * otherwise retry with a fresh D.
    '''
    n = len(A.mat)
    while True:
        # Draw a new D on every attempt so a retry can escape an unlucky choice.
        D = random_poly(n)
        while any(not x for x in D):
            D = random_poly(n)
        AD = deepcopy(A)
        for i, d in enumerate(D):
            AD.apply(i, d)
        mp = mat_minpoly(AD)
        if mp[-1] == 0:
            return 0
        if len(mp) != n + 1:
            continue
        det = -mp[-1] if n & 1 else mp[-1]
        Ddet = 1
        for d in D:
            Ddet = Ddet * d % MOD
        return det * modinv(Ddet, MOD) % MOD
# https://yukicoder.me/problems/no/1112
K, M, N = rd(), rd(), rd()
# State p * K + q stands for the ordered pair of consecutive notes (p, q), 0-based.
# Each input triple (p, q, r) adds a transition from state (p, q) to state (q, r).
m = ModMatrix(K * K)
for _ in range(M):
    p, q, r = rd() - 1, rd() - 1, rd() - 1
    m.add(p * K + q, q * K + r, 1)
# Start vector: 1 on every state (i, 0), i.e. pairs whose second note is index 0.
b = [0] * (K * K)
for i in range(K):
    b[i * K] = 1
# res[p * K + q] counts (mod MOD) the length-N sequences that begin with (p, q),
# move only along the given triples, and end on note index 0.
res = fast_pow(m, b, N - 2)
# The answer sums the K states (0, q), i.e. sequences beginning on note index 0,
# so the slice length is K (number of notes), not N (sequence length). Entries
# are already reduced, but their sum can exceed MOD.
# NOTE(review): arithmetic is modulo MOD = 998244353 because the NTT tables are
# built for that prime; confirm that is the modulus the problem statement asks for.
wtn(sum(res[:K]) % MOD)