結果

問題 No.1112 冥界の音楽
ユーザー abUmaabUma
提出日時 2023-06-14 02:42:10
言語 PyPy3 (7.3.15)
結果
WA  
実行時間 -
コード長 19,486 bytes
コンパイル時間 663 ms
コンパイル使用メモリ 82,696 KB
実行使用メモリ 80,184 KB
最終ジャッジ日時 2024-06-22 14:37:53
合計ジャッジ時間 4,873 ms
ジャッジサーバーID (参考情報) judge2 / judge4
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間 実行使用メモリ
testcase_00 AC 85 ms
78,432 KB
testcase_01 AC 72 ms
77,608 KB
testcase_02 AC 72 ms
77,580 KB
testcase_03 WA -
testcase_04 WA -
testcase_05 WA -
testcase_06 WA -
testcase_07 AC 71 ms
77,808 KB
testcase_08 AC 72 ms
78,048 KB
testcase_09 AC 85 ms
78,364 KB
testcase_10 AC 71 ms
78,248 KB
testcase_11 WA -
testcase_12 AC 70 ms
77,580 KB
testcase_13 WA -
testcase_14 WA -
testcase_15 WA -
testcase_16 AC 73 ms
77,720 KB
testcase_17 WA -
testcase_18 WA -
testcase_19 AC 73 ms
77,768 KB
testcase_20 WA -
testcase_21 WA -
testcase_22 WA -
testcase_23 WA -
testcase_24 AC 75 ms
77,588 KB
testcase_25 AC 73 ms
77,668 KB
testcase_26 WA -
testcase_27 WA -
testcase_28 WA -
testcase_29 AC 72 ms
77,792 KB
testcase_30 WA -
testcase_31 WA -
testcase_32 WA -
testcase_33 WA -
testcase_34 AC 73 ms
77,536 KB
testcase_35 AC 76 ms
77,680 KB
testcase_36 WA -
権限があれば一括ダウンロードができます

ソースコード

diff #

from copy import deepcopy
from random import randint
from __pypy__.builders import StringBuilder
import sys
from os import read as os_read, write as os_write
from atexit import register as atexist_register
from typing import Generic, Iterator, List, Tuple, Dict, Iterable, Sequence, Callable, Union, Optional, TypeVar
# Generic type variable for typing helpers.
T = TypeVar('T')

# Integer containers. A Poly stores the x**i coefficient at index i.
Vector = List[int]
Poly = List[int]
Matrix = List[List[int]]
Graph = List[List[int]]  # list of int lists (not used in this program)

# Callback signatures, named Func<number of int args><1 if it returns int, else 0>.
Func10 = Callable[[int], None]
Func11 = Callable[[int], int]
Func20 = Callable[[int, int], None]
Func21 = Callable[[int, int], int]
Func31 = Callable[[int, int, int], int]

class Fastio:
    '''Buffered stdin reader (os.read on fd 0) and stdout writer (os.write on fd 1).

    Input is pulled in chunks of up to 131072 bytes. Output is accumulated in
    a PyPy StringBuilder and written by `flush` / `flush_atexit`.
    '''
    ibuf = bytes()        # input buffer; bytes before `pil` are already consumed
    pil = pir = 0         # read cursor / end of valid data in ibuf
    sb = StringBuilder()  # pending output

    def load(self):
        '''Drop consumed bytes and append the next chunk read from stdin.'''
        self.ibuf = self.ibuf[self.pil:]
        self.ibuf += os_read(0, 131072)
        self.pil = 0; self.pir = len(self.ibuf)

    def _peek(self):
        '''Return the byte at the cursor, refilling the buffer when it runs out; -1 at EOF.'''
        if self.pil >= self.pir:
            self.load()
            if self.pil >= self.pir: return -1
        return self.ibuf[self.pil]

    @staticmethod
    def _write_all(data):
        '''Write every byte of `data` to fd 1; a single os.write may write only part.'''
        view = memoryview(data)
        while view:
            view = view[os_write(1, view):]

    def flush_atexit(self):
        '''Write all buffered output (registered to run at interpreter exit).'''
        self._write_all(self.sb.build().encode())

    def flush(self):
        '''Write all buffered output now and start a fresh buffer.'''
        self._write_all(self.sb.build().encode())
        self.sb = StringBuilder()

    def fastin(self):
        '''Read the next optionally negative decimal integer.

        Bytes below '-' (45) are skipped first. The number ends at the first
        byte below '0' (48) or at EOF; a number split across read chunks is
        handled by refilling the buffer.
        Raises EOFError if stdin ends before a number starts.
        '''
        # Prefetch so the common case needs no refill inside the loops.
        if self.pir - self.pil < 64: self.load()
        c = self._peek()
        while 0 <= c < 45:
            self.pil += 1
            c = self._peek()
        if c < 0: raise EOFError('no integer left on stdin')
        minus = c == 45
        if minus:
            self.pil += 1
            c = self._peek()
        x = 0
        while c >= 48:
            x = x * 10 + (c & 15)
            self.pil += 1
            c = self._peek()
        return -x if minus else x

    def fastin_string(self):
        '''Read the next token (a run of bytes > 32) as a bytearray.

        Raises EOFError if stdin ends before a token starts.
        '''
        if self.pir - self.pil < 64: self.load()
        c = self._peek()
        while 0 <= c <= 32:
            self.pil += 1
            c = self._peek()
        if c < 0: raise EOFError('no token left on stdin')
        res = bytearray()
        while c > 32:
            res.append(c)
            self.pil += 1
            c = self._peek()
        return res

    def fastout(self, x): self.sb.append(str(x))

    def fastoutln(self, x): self.sb.append(str(x)); self.sb.append('\n')
# Single shared I/O object and short aliases for its methods.
fastio = Fastio()
rd = fastio.fastin
rds = fastio.fastin_string
wt = fastio.fastout
wtn = fastio.fastoutln
flush = fastio.flush
# Emit whatever is still buffered when the interpreter exits.
atexist_register(fastio.flush_atexit)
# All I/O goes through fastio; detach the standard text streams.
sys.stdin = None
sys.stdout = None


def rdl(n):
    '''Read n integers and return them as a list.'''
    return list(map(lambda _: rd(), range(n)))


def wtnl(l):
    '''Write the items of l separated by single spaces, followed by a newline.'''
    wtn(' '.join(str(item) for item in l))


def wtn_yes():
    '''Write the line "Yes".'''
    wtn("Yes")


def wtn_no():
    '''Write the line "No".'''
    wtn("No")

def modinv(a: int, m: int) -> int:
    '''Return x in [0, m) with a * x == 1 (mod m), via the extended Euclidean algorithm.

    Invertibility is not checked: when gcd(a, m) != 1 the result is not an
    inverse (for example a == 0 gives 0).
    '''
    old_r, r = a, m
    old_s, s = 1, 0
    while r:
        q = old_r // r
        old_r, r = r, old_r - q * r
        old_s, s = s, old_s - q * s
    return old_s % m

# https://nyaannyaan.github.io/library/fps/berlekamp-massey.hpp
def berlekamp_massey(s: Vector, mod: int) -> Vector:
    '''Return the shortest linear recurrence of s over Z/mod (Berlekamp-Massey).

    The result c has c[0] == 1 and satisfies
    sum(c[j] * s[t - j] for j in range(len(c))) % mod == 0
    for len(c) - 1 <= t < len(s).
    mod should be prime: each nonzero discrepancy is divided by an earlier one
    through modinv.
    '''
    N = len(s)
    # b, c are stored constant-term-last; c[-1] stays 1 throughout.
    b = [1]  # connection polynomial saved at the last length change
    c = [1]  # current connection polynomial
    y = 1    # discrepancy recorded when b was saved
    for ed in range(1, N + 1):
        l = len(c)
        m = len(b)
        x = 0
        # Discrepancy of c at the newest term: c[-1] pairs with s[ed - 1].
        for i, a in enumerate(c): x += a * s[ed - l + i]
        x %= mod
        # b is shifted one more position relative to c every step.
        b.append(0)
        m += 1
        if x == 0: continue
        freq = x * modinv(y, mod) % mod
        if l < m:
            # The recurrence must get longer: pad c, correct it with b, and
            # keep the old c (with its discrepancy) as the new b.
            tmp = c[:]
            c[:0] = [0] * (m - l)
            for i in range(m): c[m - 1 - i] = (c[m - 1 - i] - freq * b[m - 1 - i]) % mod
            b = tmp
            y = x
        else:
            # Same length suffices: correct c in place with the aligned b.
            for i in range(m): c[l - 1 - i] = (c[l - 1 - i] - freq * b[m - 1 - i]) % mod
    # Return with the constant term (1) first.
    c.reverse()
    return c

# Prime modulus for all arithmetic; 998244353 == 119 * 2**23 + 1, so Z/MOD has the
# power-of-two roots of unity that the NTT below relies on.
MOD = 998244353
# Twiddle for the radix-4 butterflies in NTT._fft; presumably a primitive 4th root of
# unity (sqrt(-1)) modulo MOD, as in AtCoder Library -- TODO confirm.
_IMAG = 911660635
# Counterpart used by NTT._ifft. Note _IMAG + _IIMAG == MOD, i.e. _IIMAG == -_IMAG (mod MOD).
_IIMAG = 86583718
# Per-level rotation factors (presumably ACL's rate2 / rate3 / irate3 for this MOD --
# TODO confirm). NTT indexes them with (~s & -~s).bit_length(), which is always >= 1,
# so entry 0 is only padding.
_rate2 = (0, 911660635, 509520358, 369330050, 332049552, 983190778, 123842337, 238493703, 975955924, 603855026, 856644456, 131300601, 842657263, 730768835, 942482514, 806263778, 151565301, 510815449, 503497456, 743006876, 741047443, 56250497, 867605899, 0)
_rate3 = (0, 372528824, 337190230, 454590761, 816400692, 578227951, 180142363, 83780245, 6597683, 70046822, 623238099, 183021267, 402682409, 631680428, 344509872, 689220186, 365017329, 774342554, 729444058, 102986190, 128751033, 395565204, 0)
_irate3 = (0, 509520358, 929031873, 170256584, 839780419, 282974284, 395914482, 444904435, 72135471, 638914820, 66769500, 771127074, 985925487, 262319669, 262341272, 625870173, 768022760, 859816005, 914661783, 430819711, 272774365, 530924681, 0)

class NTT:
    '''In-place number-theoretic transforms and convolution modulo MOD.

    _fft / _ifft appear to follow AtCoder Library's butterfly layout: radix-4
    passes plus at most one radix-2 pass. len(a) is assumed to be a power of
    two -- TODO confirm for external callers. Neither one normalises:
    `intt` and `multiply` divide the inverse transform by the length.
    '''
    @staticmethod
    def _fft(a: Vector) -> None:
        '''Forward transform of a in place, without normalisation.

        The output order is presumably bit-reversed, as in ACL.
        '''
        n = len(a)
        h = (n - 1).bit_length()  # log2(n) when n is a power of two
        le = 0
        # Radix-4 passes, each covering two levels.
        for le in range(0, h - 1, 2):
            p = 1 << (h - le - 2)  # stride between the four butterfly inputs
            rot = 1
            for s in range(1 << le):
                rot2 = rot * rot % MOD
                rot3 = rot2 * rot % MOD
                offset = s << (h - le)
                for i in range(p):
                    a0 = a[i + offset]
                    a1 = a[i + offset + p] * rot
                    a2 = a[i + offset + p * 2] * rot2
                    a3 = a[i + offset + p * 3] * rot3
                    a1na3imag = (a1 - a3) % MOD * _IMAG
                    a[i + offset] = (a0 + a2 + a1 + a3) % MOD
                    a[i + offset + p] = (a0 + a2 - a1 - a3) % MOD
                    a[i + offset + p * 2] = (a0 - a2 + a1na3imag) % MOD
                    a[i + offset + p * 3] = (a0 - a2 - a1na3imag) % MOD
                # (~s & -~s) isolates the lowest zero bit of s; its bit_length is a
                # 1-based index into the padded rate table.
                rot = rot * _rate3[(~s & -~s).bit_length()] % MOD
        # Parsed as (h - le) & 1. After the loop, `le` is the first level of the last
        # radix-4 pass (0 if none ran), so this holds exactly when one level remains.
        if h - le & 1:
            rot = 1
            for s in range(1 << (h - 1)):
                offset = s << 1
                l = a[offset]
                r = a[offset + 1] * rot
                a[offset] = (l + r) % MOD
                a[offset + 1] = (l - r) % MOD
                rot = rot * _rate2[(~s & -~s).bit_length()] % MOD

    @staticmethod
    def _ifft(a: Vector) -> None:
        '''Inverse of _fft in place, up to a factor of len(a) (no normalisation).'''
        n = len(a)
        h = (n - 1).bit_length()  # log2(n) when n is a power of two
        le = h
        # Radix-4 passes from the top level down, two levels each.
        for le in range(h, 1, -2):
            p = 1 << (h - le)
            irot = 1
            for s in range(1 << (le - 2)):
                irot2 = irot * irot % MOD
                irot3 = irot2 * irot % MOD
                offset = s << (h - le + 2)
                for i in range(p):
                    a0 = a[i + offset]
                    a1 = a[i + offset + p]
                    a2 = a[i + offset + p * 2]
                    a3 = a[i + offset + p * 3]
                    a2na3iimag = (a2 - a3) * _IIMAG % MOD
                    a[i + offset] = (a0 + a1 + a2 + a3) % MOD
                    a[i + offset + p] = (a0 - a1 + a2na3iimag) * irot % MOD
                    a[i + offset + p * 2] = (a0 + a1 - a2 - a3) * irot2 % MOD
                    a[i + offset + p * 3] = (a0 - a1 - a2na3iimag) * irot3 % MOD
                irot = irot * _irate3[(~s & -~s).bit_length()] % MOD
        # After the loop `le` is the last radix-4 level (h if none ran). It is odd
        # exactly when one radix-2 level remains; that level is a single block,
        # so no twiddle factor is needed.
        if le & 1:
            p = 1 << (h - 1)
            for i in range(p):
                l = a[i]
                r = a[i + p]
                # Assumes l and r are already reduced into [0, MOD).
                a[i] = l + r if l + r < MOD else l + r - MOD
                a[i + p] = l - r if l - r >= 0 else l - r + MOD

    @classmethod
    def ntt(cls, a: Vector) -> None:
        '''Forward transform in place; lengths 0 and 1 are left unchanged.'''
        if len(a) <= 1: return
        cls._fft(a)

    @classmethod
    def intt(cls, a:Vector) -> None:
        '''Inverse transform in place, including the division by len(a).'''
        if len(a) <= 1: return
        cls._ifft(a)
        iv = modinv(len(a), MOD)
        for i, x in enumerate(a): a[i] = x * iv % MOD

    @classmethod
    def multiply(cls, s: Vector, t: Vector) -> Vector:
        '''Return the linear convolution of s and t modulo MOD.

        The result has len(s) + len(t) - 1 entries. An empty operand therefore
        gives len(other) - 1 zeros, or [] if both are empty. When either operand
        has at most 60 terms, schoolbook multiplication is used; otherwise both
        are zero-padded to a power of two and multiplied through the NTT.
        '''
        n, m = len(s), len(t)
        l = n + m - 1
        if min(n, m) <= 60:
            a = [0] * l
            for i, x in enumerate(s):
                for j, y in enumerate(t):
                    a[i + j] += x * y
            return [x % MOD for x in a]
        z = 1 << (l - 1).bit_length()  # smallest power of two >= l
        a = s + [0] * (z - n)
        b = t + [0] * (z - m)
        cls._fft(a)
        cls._fft(b)
        for i, x in enumerate(b): a[i] = a[i] * x % MOD
        cls._ifft(a)
        a[l:] = []
        iz = modinv(z, MOD)
        return [x * iz % MOD for x in a]

    @classmethod
    def pow2(cls, s: Vector) -> Vector:
        '''Return s convolved with itself modulo MOD (2*len(s) - 1 entries).

        Uses the schoolbook method for at most 60 terms and one forward
        transform otherwise.
        '''
        n = len(s)
        l = (n << 1) - 1
        if n <= 60:
            a = [0] * l
            for i, x in enumerate(s):
                for j, y in enumerate(s):
                    a[i + j] += x * y
            return [x % MOD for x in a]
        z = 1 << (l - 1).bit_length()  # smallest power of two >= l
        a = s + [0] * (z - n)
        cls._fft(a)
        for i, x in enumerate(a): a[i] = x * x % MOD
        cls._ifft(a)
        a[l:] = []
        iz = modinv(z, MOD)
        return [x * iz % MOD for x in a]

    @classmethod
    def ntt_doubling(cls, a: Vector) -> None:
        '''Append len(a) more transform values to a, in place.

        It inverse-transforms a copy of a and multiplies coefficient i by zeta**i,
        where zeta = 3**((MOD - 1) / (2 * len(a))). It then forward-transforms
        the copy and appends it to a. Presumably this extends a length-M
        transform to length 2M (port of nyaan's ntt_doubling). It is not used
        in this program -- TODO verify it against this butterfly's ordering.
        '''
        M = len(a)
        b = a[:]
        cls.intt(b)
        r = 1
        zeta = pow(3, (MOD - 1) // (M << 1), MOD)
        for i, x in enumerate(b):
            b[i] = x * r % MOD
            r = r * zeta % MOD
        cls.ntt(b)
        a += b

# https://nyaannyaan.github.io/library/fps/formal-power-series.hpp
# https://nyaannyaan.github.io/library/fps/ntt-friendly-fps.hpp
class FPS:
    '''Polynomial / formal-power-series helpers over Z/MOD (static and class methods only).

    A Poly is a coefficient list with the x**i coefficient at index i.
    `shrink` and `resize` modify their argument in place. Every other method
    leaves its inputs untouched and returns a new value.
    '''
    @staticmethod
    def shrink(a: Poly) -> None:
        '''Remove zero coefficients from the high-degree end of a, in place.'''
        while a and not a[-1]: a.pop()

    @staticmethod
    def resize(a: Poly, length: int, val: int=0) -> None:
        '''Truncate a to `length` coefficients, or pad it with `val`, in place.'''
        a[length:] = []
        a[len(a):] = [val] * (length - len(a))

    @staticmethod
    def add(l: Poly, r: Union[Poly, int]) -> Poly:
        '''Return l + r, where r is a polynomial or an integer constant.

        For a constant, only the x**0 coefficient is reduced modulo MOD, and
        an empty l counts as the zero polynomial. For a polynomial, every
        coefficient of the result is reduced.
        Raises TypeError for any other type of r.
        '''
        if type(r) is int:
            # `or [0]`: adding a constant to the empty polynomial must not raise IndexError.
            res = l[:] or [0]
            res[0] = (res[0] + r) % MOD
            return res
        if type(r) is list:
            if len(l) < len(r):
                res = r[::]
                for i, x in enumerate(l): res[i] += x
            else:
                res = l[::]
                for i, x in enumerate(r): res[i] += x
            return [x % MOD for x in res]
        raise TypeError()

    @classmethod
    def sub(cls, l: Poly, r: Union[Poly, int]) -> Poly:
        '''Return l - r for a polynomial or integer constant r (same rules as `add`).'''
        if type(r) is int: return cls.add(l, -r)
        if type(r) is list: return cls.add(l, cls.neg(r))
        raise TypeError()

    @staticmethod
    def neg(a: Poly) -> Poly:
        '''Return -a with coefficients in [0, MOD); assumes a is already reduced.'''
        return [MOD - x if x else 0 for x in a]

    @staticmethod
    def mul(l: Poly, r: Union[Poly, int]) -> Poly:
        '''Return l * r: the scalar multiple for an int r, the convolution for a list r.

        A product with an empty polynomial is []. Raises TypeError for any other type of r.
        '''
        if type(r) is int: return [x * r % MOD for x in l]
        if type(r) is list:
            if not l or not r: return []
            return NTT.multiply(l, r)
        raise TypeError()

    @staticmethod
    def matmul(l: Poly, r: Poly) -> Poly:
        '''Return the coefficient-wise product of l and r modulo MOD.

        The result has len(l) entries, so r must be at least as long.
        The original author marked this "not verified".
        '''
        return [x * r[i] % MOD for i, x in enumerate(l)]

    @classmethod
    def div(cls, l: Poly, r: Poly) -> Poly:
        '''Return the quotient quo of polynomial division l = r*quo + rem.

        Trailing zero coefficients of r are dropped first, so both the
        schoolbook path (<= 64 coefficients) and the Newton-inverse path
        divide by a nonzero leading coefficient. Returns [] when l has fewer
        coefficients than the trimmed r.
        Raises ZeroDivisionError if r is the zero polynomial.
        '''
        g = r[:]
        cls.shrink(g)
        if not g: raise ZeroDivisionError('polynomial division by zero')
        if len(l) < len(g): return []
        n = len(l) - len(g) + 1  # number of quotient coefficients
        if len(g) > 64:
            # Reversed polynomials turn division into multiplication by a power-series inverse.
            return NTT.multiply(l[::-1][:n], cls.inv(g[::-1], n))[:n][::-1]
        coef = modinv(g[-1], MOD)
        g = cls.mul(g, coef)  # make g monic; the scaling is undone on the quotient
        f = l[:]
        gs = len(g)
        quo = [0] * n
        for i in reversed(range(n)):
            quo[i] = x = f[i + gs - 1] % MOD
            for j, y in enumerate(g):
                f[i + j] -= x * y
        return cls.mul(quo, coef)

    @classmethod
    def modulo(cls, l: Poly, r: Poly) -> Poly:
        '''Return the remainder rem of l = r*quo + rem, with trailing zeros removed.'''
        res = cls.sub(l, NTT.multiply(cls.div(l, r), r))
        cls.shrink(res)
        return res

    @classmethod
    def divmod(cls, l: Poly, r: Poly) -> Tuple[Poly, Poly]:
        '''Return (quo, rem) with l = r*quo + rem; rem has trailing zeros removed.'''
        quo = cls.div(l, r)
        rem = cls.sub(l, NTT.multiply(quo, r))
        cls.shrink(rem)
        return quo, rem

    @staticmethod
    def eval(a: Poly, x: int) -> int:
        '''Return a(x) modulo MOD.'''
        r = 0; w = 1
        for v in a:
            r += w * v % MOD
            w = w * x % MOD
        return r % MOD

    @staticmethod
    def inv(a: Poly, deg: int=-1) -> Poly:
        '''Return g with a*g == 1 (mod x**deg); deg defaults to len(a).

        a[0] must be nonzero modulo MOD (not checked). deg == 0 yields [].
        Newton iteration: each round doubles the number d of correct terms.
        '''
        if deg == -1: deg = len(a)
        if deg == 0: return []
        res = [0] * deg
        res[0] = modinv(a[0], MOD)
        d = 1
        while d < deg:
            f = [0] * (d << 1)
            tmp = min(len(a), d << 1)
            f[:tmp] = a[:tmp]
            g = [0] * (d << 1)
            g[:d] = res[:d]
            NTT.ntt(f)
            NTT.ntt(g)
            for i, x in enumerate(g): f[i] = f[i] * x % MOD
            NTT.intt(f)
            # Drop the low d terms of the (cyclic) product a*g and multiply by g
            # again; the negated upper half gives res's next d coefficients.
            f[:d] = [0] * d
            NTT.ntt(f)
            for i, x in enumerate(g): f[i] = f[i] * x % MOD
            NTT.intt(f)
            for j in range(d, min(d << 1, deg)):
                if f[j]: res[j] = MOD - f[j]
                else: res[j] = 0
            d <<= 1
        return res

    @classmethod
    def pow(cls, f: Poly, k: int, deg=-1) -> Poly:
        '''Return f**k mod x**deg as exactly deg coefficients; deg defaults to len(f).

        With i the index of f's lowest nonzero coefficient, f = f[i] * x**i * g
        where g[0] == 1, and g**k is computed as exp(k * log(g)).
        '''
        n = len(f)
        if deg == -1: deg = n
        if k == 0:
            if not deg: return []
            ret = [0] * deg
            ret[0] = 1
            return ret
        for i, x in enumerate(f):
            if x:
                rev = modinv(x, MOD)
                ret = cls.mul(cls.exp(cls.mul(cls.log(cls.mul(f, rev)[i:], deg),  k), deg), pow(x, k, MOD))
                ret[:0] = [0] * (i * k)
                if len(ret) < deg:
                    cls.resize(ret, deg)
                    return ret
                return ret[:deg]
            # The lowest nonzero term is at index >= i + 1, so its k-th power
            # already lies at or beyond x**deg: the result is all zeros.
            if (i + 1) * k >= deg: break
        return [0] * deg

    @staticmethod
    def exp(f: Poly, deg: int=-1) -> Poly:
        '''Return g with log(g) == f (mod x**deg); deg defaults to len(f).

        f[0] is expected to be 0 (not checked). This ports the Newton iteration
        from the ntt-friendly-fps library linked above: each round doubles the
        number m of known terms of b == exp(f).
        '''
        if deg == -1: deg = len(f)
        inv = [0, 1]  # inv[i] = modular inverse of i, grown on demand by `integral`

        def integral(f: Poly) -> Poly:
            '''Antiderivative with zero constant term, reusing the shared inverse table.'''
            n = len(f)
            while len(inv) <= n:
                j, k = divmod(MOD, len(inv))
                inv.append((-inv[k] * j) % MOD)
            return [0] + [x * inv[i + 1] % MOD for i, x in enumerate(f)]

        def diff(f: Poly) -> Poly:
            '''Formal derivative.'''
            return [x * i % MOD for i, x in enumerate(f) if i]

        b: Poly = [1, (f[1] if 1 < len(f) else 0)]  # exp(f) mod x**2
        # c / z1 / z2 follow the referenced implementation (there, c tracks the
        # inverse of b to half precision and z2 is its transform).
        c: Poly = [1]
        z1: Poly= []
        z2: Poly = [1, 1]
        m = 2
        while m < deg:
            y = b + [0] * m
            NTT.ntt(y)
            z1 = z2
            z = [y[i] * p % MOD for i, p in enumerate(z1)]
            NTT.intt(z)
            z[:m >> 1] = [0] * (m >> 1)
            NTT.ntt(z)
            for i, p in enumerate(z1): z[i] = z[i] * (-p) % MOD
            NTT.intt(z)
            c[m >> 1:] = z[m >> 1:]
            z2 = c + [0] * m
            NTT.ntt(z2)
            tmp = min(len(f), m)
            x = f[:tmp] + [0] * (m - tmp)
            x = diff(x)
            x.append(0)
            NTT.ntt(x)
            for i, p in enumerate(x): x[i] = y[i] * p % MOD
            NTT.intt(x)
            # Subtract b' term by term.
            for i, p in enumerate(b):
                if not i: continue
                x[i - 1] -= p * i % MOD
            x += [0] * m
            for i in range(m - 1): x[m + i], x[i] = x[i], 0
            NTT.ntt(x)
            for i, p in enumerate(z2): x[i] = x[i] * p % MOD
            NTT.intt(x)
            x.pop()
            x = integral(x)
            x[:m] = [0] * m
            for i in range(m, min(len(f), m << 1)): x[i] += f[i]
            NTT.ntt(x)
            for i, p in enumerate(y): x[i] = x[i] * p % MOD
            NTT.intt(x)
            b[m:] = x[m:]  # b now holds 2m terms
            m <<= 1
        return b[:deg]

    @classmethod
    def log(cls, f: Poly, deg=-1) -> Poly:
        '''Return g with g == log(f) (mod x**deg); f[0] must be 1 (not checked).'''
        if deg == -1: deg = len(f)
        return cls.integral(cls.mul(cls.diff(f), cls.inv(f, deg))[:deg - 1])

    @staticmethod
    def integral(f: Poly) -> Poly:
        '''Return the antiderivative of f with zero constant term (len(f) + 1 terms).'''
        n = len(f)
        res = [0] * (n + 1)
        # First fill res[i] with the inverse of i: inv(i) = -(MOD // i) * inv(MOD % i).
        if n: res[1] = 1
        for i in range(2, n + 1):
            j, k = divmod(MOD, i)
            res[i] = (-res[k] * j) % MOD
        for i, x in enumerate(f): res[i + 1] = res[i + 1] * x % MOD
        return res

    @staticmethod
    def diff(f: Poly) -> Poly:
        '''Return the formal derivative df/dx (len(f) - 1 terms).'''
        return [i * x % MOD for i, x in enumerate(f) if i]

# https://nyaannyaan.github.io/library/fps/mod-pow.hpp
def mod_pow(k: int, base: Poly, d: Poly) -> Poly:
    '''Return base**k modulo the polynomial d, with trailing zeros removed.

    d[-1] (the leading coefficient) must be invertible modulo MOD; the
    caller in this file passes a polynomial with d[-1] == 1. base is not
    reduced before use. For k == 0 the result is [1].
    Raises ValueError if k is negative, since the bit loop would never end.
    '''
    assert(d)
    if k < 0: raise ValueError('k must be non-negative')
    # Power-series inverse of d reversed (len(d) terms); it turns division by d
    # into a single multiplication in `quo`.
    inv = FPS.inv(d[::-1])

    def quo(poly: Poly) -> Poly:
        '''Return the quotient of poly divided by d.'''
        if len(poly) < len(d): return []
        n = len(poly) - len(d) + 1
        # Top n coefficients of poly, highest degree first. This is written as
        # poly[::-1][:n] because an explicit slice stop of len(poly) - n - 1
        # becomes -1 when len(d) == 1 and then selects nothing.
        return NTT.multiply(poly[::-1][:n], inv[:n])[n - 1::-1]

    res = [1]
    b = base[:]
    while k:
        if k & 1:
            res = NTT.multiply(res, b)
            res = FPS.sub(res, NTT.multiply(quo(res), d))
            FPS.shrink(res)
        b = NTT.pow2(b)
        b = FPS.sub(b, NTT.multiply(quo(b), d))
        FPS.shrink(b)
        k >>= 1
        # After each reduction, b and res have fewer than len(d) coefficients.
    return res

# https://nyaannyaan.github.io/library/matrix/black-box-linear-algebra.hpp
def inner_product(a: Poly, b: Poly) -> int:
    '''Return the dot product of the equal-length vectors a and b modulo MOD.'''
    assert(len(a) == len(b))
    return sum(x * y % MOD for x, y in zip(a, b)) % MOD

def random_poly(n: int) -> Poly:
    '''Return n residues drawn uniformly and independently from [0, MOD).'''
    upper = MOD - 1
    return [randint(0, upper) for _ in range(n)]

class ModMatrix:
    '''Dense n x n matrix over Z/MOD.

    It supports accumulating entries, matrix-vector products and in-place
    row scaling.
    '''

    def __init__(self, n: int) -> None:
        self.mat = [[0 for _ in range(n)] for _ in range(n)]

    def add(self, i: int, j: int, x: int) -> None:
        '''Add x to entry (i, j); the value is not reduced here.'''
        row = self.mat[i]
        row[j] += x

    def __mul__(self, r: Poly) -> Poly:
        '''Return the matrix-vector product self * r modulo MOD.'''
        assert(len(self.mat) == len(r))
        out = []
        for row in self.mat:
            acc = 0
            for coef, val in zip(row, r):
                acc += coef * val % MOD
            out.append(acc % MOD)
        return out

    def apply(self, i: int, r: int) -> None:
        '''Multiply row i by r in place (the same list object is updated).'''
        row = self.mat[i]
        row[:] = [v * r % MOD for v in row]

class SparseMatrix:
    '''Sparse square matrix over Z/MOD: one list of packed entries per row.

    The entry (column j, value x) is stored as the single int (j << 30) | x.
    That packing needs 0 <= x < 2**30, so values are reduced modulo MOD
    (< 2**30) when they are added. Repeated (i, j) entries are kept as
    separate items and summed by `*`.
    '''
    _SHIFT = 30
    _MASK = (1 << 30) - 1  # low 30 bits hold the value (== 0x3fffffff)

    def __init__(self, n: int) -> None:
        self.mat: List[List[int]] = [[] for _ in range(n)]

    def add(self, i: int, j: int, x: int) -> None:
        '''Record value x at (i, j).

        x is reduced modulo MOD first. A negative value, or one >= 2**30, would
        otherwise spill into the packed column index.
        '''
        self.mat[i].append((j << self._SHIFT) | (x % MOD))

    def __mul__(self, r: Poly) -> Poly:
        '''Return the matrix-vector product self * r modulo MOD.'''
        assert(len(self.mat) == len(r))
        shift, mask = self._SHIFT, self._MASK
        return [sum((jx & mask) * r[jx >> shift] % MOD for jx in row) % MOD for row in self.mat]

    def apply(self, i: int, r: int) -> None:
        '''Multiply every value in row i by r in place (reduced modulo MOD).'''
        shift, mask = self._SHIFT, self._MASK
        row = self.mat[i]
        for idx, jx in enumerate(row):
            row[idx] = ((jx >> shift) << shift) | ((jx & mask) * r % MOD)

def vector_minpoly(b: List[Poly]) -> Poly:
    '''Return the Berlekamp-Massey recurrence of the vector sequence b.

    Every vector is projected onto one random vector, and the scalar
    sequence is passed to berlekamp_massey. Because the projection is
    random, the result can (with small probability) be a proper factor of
    the true minimal polynomial.
    '''
    assert(b)
    proj = random_poly(len(b[0]))
    seq = [inner_product(vec, proj) for vec in b]
    return berlekamp_massey(seq, MOD)

def mat_minpoly(A: Union[ModMatrix, SparseMatrix]) -> Poly:
    '''Return a probabilistic minimal polynomial of A in berlekamp_massey form.

    The result has c[0] == 1; reverse it to obtain the monic polynomial. It is
    computed from the Krylov sequence u, A u, ..., A**(2n) u for a random u,
    where n = len(A.mat).
    '''
    size = len(A.mat)
    vec = random_poly(size)
    krylov: List[Poly] = []
    for _ in range((size << 1) + 1):
        krylov.append(vec)
        vec = A * vec
    return vector_minpoly(krylov)

def fast_pow(A: Union[ModMatrix, SparseMatrix], b: Poly, k: int) -> Poly:
    '''Return A**k * b modulo MOD.

    x**k is reduced modulo the monic minimal polynomial of A. The resulting
    coefficients c_i are then combined as sum(c_i * A**i * b). The minimal
    polynomial is found probabilistically (see mat_minpoly).
    '''
    coeffs = mod_pow(k, [0, 1], mat_minpoly(A)[::-1])
    acc = [0] * len(b)
    krylov = b
    for coef in coeffs:
        acc = FPS.add(acc, FPS.mul(krylov, coef))
        krylov = A * krylov
    return acc

def fast_det(A: Union[ModMatrix, SparseMatrix]) -> int:
    '''Return det(A) modulo MOD, using the randomized black-box method.

    The rows of A are scaled by a random nonzero diagonal D. Once the minimal
    polynomial of the scaled matrix AD has full degree n, it is the
    characteristic polynomial, and its constant term gives det(AD) =
    det(A) * prod(D). The answer is correct with high probability.
    '''
    n = len(A.mat)
    while True:
        # Draw a fresh nonzero diagonal on every attempt. Reusing the same D
        # could loop forever when AD's minimal polynomial has degree < n.
        D = random_poly(n)
        while not all(D): D = random_poly(n)
        AD = deepcopy(A)
        for i, d in enumerate(D): AD.apply(i, d)
        mp = mat_minpoly(AD)
        # mp[-1] is the constant term of the monic minimal polynomial; zero
        # means x divides it, so AD (and hence A) is singular.
        if mp[-1] == 0: return 0
        if len(mp) != n + 1: continue
        # Characteristic polynomial p(x) = det(xI - AD), so p(0) = (-1)**n * det(AD).
        det = -mp[-1] if n & 1 else mp[-1]
        Ddet = 1
        for d in D: Ddet = Ddet * d % MOD
        return det * modinv(Ddet, MOD) % MOD

# https://yukicoder.me/problems/no/1112
# A state is an ordered pair (p, q) of consecutive 0-indexed notes, stored at index p*K + q.
# Entry (p*K + q, q*K + r) is 1 for each allowed triple (p, q, r), so (A**t v)[(p, q)]
# sums v over the states reachable from (p, q) in t steps, counting paths.
K, M, N = rd(), rd(), rd()
m = ModMatrix(K * K)
for _ in range(M):
    p, q, r = rd() - 1, rd() - 1, rd() - 1
    m.add(p * K + q, q * K + r, 1)
# Weight 1 on every pair (x, 0): the melody ends on note 1.
b = [0] * (K * K)
for i in range(K): b[i * K] = 1
# res[p*K + q] = number of length-N melodies that start with notes p, q and end on note 1.
res = fast_pow(m, b, N - 2)
# Assuming melodies must also start on note 1, sum the states (0, q), i.e. indices 0..K-1.
# The previous res[:N] sliced the state vector by melody length, which is unrelated
# to the K*K state layout.
# NOTE(review): everything is computed modulo MOD = 998244353 because the NTT helpers
# require that prime; confirm this is the modulus the problem asks for.
wtn(sum(res[:K]) % MOD)