結果

問題 No.287 場合の数
ユーザー toyuzukotoyuzuko
提出日時 2023-05-05 11:33:35
言語 PyPy3 (7.3.15)
結果
AC  
実行時間 162 ms / 5,000 ms
コード長 11,422 bytes
コンパイル時間 225 ms
コンパイル使用メモリ 81,932 KB
実行使用メモリ 79,944 KB
最終ジャッジ日時 2024-05-02 08:35:35
合計ジャッジ時間 4,772 ms
ジャッジサーバーID (参考情報) judge1 / judge4
このコードへのチャレンジ (要ログイン)

テストケース

テストケース表示
入力 結果 実行時間 実行使用メモリ
testcase_00 AC 158 ms 79,000 KB
testcase_01 AC 70 ms 69,292 KB
testcase_02 AC 74 ms 72,408 KB
testcase_03 AC 162 ms 79,944 KB
testcase_04 AC 150 ms 79,292 KB
testcase_05 AC 148 ms 78,984 KB
testcase_06 AC 75 ms 72,152 KB
testcase_07 AC 148 ms 79,088 KB
testcase_08 AC 150 ms 78,996 KB
testcase_09 AC 154 ms 78,988 KB
testcase_10 AC 155 ms 79,276 KB
testcase_11 AC 133 ms 79,136 KB
testcase_12 AC 152 ms 79,528 KB
testcase_13 AC 131 ms 79,040 KB
testcase_14 AC 145 ms 78,944 KB
testcase_15 AC 149 ms 79,084 KB
testcase_16 AC 148 ms 79,104 KB
testcase_17 AC 153 ms 78,988 KB
testcase_18 AC 146 ms 78,964 KB
testcase_19 AC 145 ms 79,008 KB
testcase_20 AC 143 ms 79,056 KB
testcase_21 AC 151 ms 79,088 KB
testcase_22 AC 153 ms 79,172 KB
testcase_23 AC 149 ms 79,112 KB
testcase_24 AC 143 ms 79,120 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

def primitive_root(m: int) -> int:
    """Return a primitive root modulo the prime ``m``.

    m == 2 and a few common NTT primes are answered from a lookup table.
    Otherwise the distinct prime factors q of m - 1 are collected by trial
    division, and the smallest g >= 2 with g**((m - 1) // q) != 1 (mod m)
    for every q is returned.

    ``m`` is assumed to be prime; primality is not checked.

    Raises:
        ValueError: if m < 2 (previously m == 1 looped forever).
    """
    if m < 2:
        raise ValueError("m must be a prime")
    known = {2: 1, 167772161: 3, 469762049: 3, 754974721: 11, 998244353: 3}
    if m in known:
        return known[m]
    # Distinct prime factors of m - 1. A growing list replaces the old
    # fixed 20-slot array, which overflowed for m - 1 with many factors.
    factors = [2]
    x = (m - 1) // 2
    while x % 2 == 0:
        x //= 2
    d = 3
    while d * d <= x:
        if x % d == 0:
            factors.append(d)
            while x % d == 0:
                x //= d
        d += 2
    if x > 1:
        factors.append(x)
    # g is a generator iff no proper-divisor exponent (m-1)/q maps it to 1.
    g = 2
    while any(pow(g, (m - 1) // q, m) == 1 for q in factors):
        g += 1
    return g

from typing import Sequence, Tuple

def inv_gcd(a: int, b: int) -> Tuple[int, int]:
    """Return ``(g, x)`` with ``g = gcd(a, b)`` and ``x * a ≡ g (mod b)``.

    ``b`` must be positive. ``a`` is first reduced modulo ``b``; if that
    leaves 0 the result is ``(b, 0)``. Otherwise an extended Euclidean loop
    tracks only the coefficient of ``a``, and a negative final coefficient
    is shifted up by ``b // g``.
    """
    a %= b
    if not a:
        return b, 0
    s, t = b, a
    x0, x1 = 0, 1
    while t != 0:
        q = s // t
        s, t = t, s - t * q
        x0, x1 = x1, x0 - x1 * q
    if x0 < 0:
        x0 += b // s
    return s, x0

def crt(r: Sequence[int], m: Sequence[int]) -> Tuple[int, int]:
    """Solve the system x ≡ r[i] (mod m[i]) for all i.

    Returns ``(y, z)`` such that the solutions are exactly x ≡ y (mod z),
    with z the lcm of the moduli, or ``(0, 0)`` if the system is
    inconsistent. Empty input yields ``(0, 1)``. Each modulus must be >= 1.
    """
    assert len(r) == len(m)
    cur_r, cur_m = 0, 1
    for idx, rem in enumerate(r):
        modulus = m[idx]
        assert 1 <= modulus
        nr, nm = rem % modulus, modulus
        # Keep the larger modulus in (cur_r, cur_m).
        if cur_m < nm:
            cur_r, nr, cur_m, nm = nr, cur_r, nm, cur_m
        if cur_m % nm == 0:
            # The new congruence is implied by (or contradicts) the current one.
            if cur_r % nm != nr:
                return 0, 0
            continue
        g, inv = inv_gcd(cur_m, nm)
        diff = nr - cur_r
        if diff % g:
            return 0, 0
        step = nm // g
        k = diff // g * inv % step
        cur_r += k * cur_m
        cur_m *= step
        if cur_r < 0:
            cur_r += cur_m
    return cur_r, cur_m

def popcount(x: int) -> int:
    """Return the number of set bits in ``x``.

    Non-negative ``x`` of any size is supported. The previous fixed-width
    SWAR version masked everything above bit 31, so ``popcount(1 << 40)``
    returned 0.

    Negative ``x`` keeps the old behaviour: it is read as a 32-bit
    two's-complement value (only the low 32 bits are counted), so
    ``popcount(-1) == 32``.
    """
    if x < 0:
        x &= 0xFFFFFFFF
    return bin(x).count("1")

def tzcount(x: int) -> int:
    """Return the number of trailing zero bits of ``x`` using ``popcount``.

    ``x & -x`` isolates the lowest set bit; subtracting 1 turns it into a
    mask of exactly the trailing zero positions, whose popcount is the
    answer. For ``x == 0`` the mask is ``-1``, so the result is whatever
    ``popcount(-1)`` reports.
    """
    lowest_bit = x & -x
    return popcount(lowest_bit - 1)

from typing import List, Callable, Union, Optional

class Convolution():
    """Number-theoretic-transform (NTT) convolution modulo a prime below 2**31.

    ``mod`` may be an int or a zero-argument callable returning the modulus.
    The modulus is assumed to be prime: inverses are taken with Fermat's
    little theorem. NTT transform lengths are limited to ``2 ** rank2``,
    where ``rank2`` is the number of trailing zero bits of ``mod - 1``.
    """

    def __init__(self, mod: Union[Callable[[], int], int]) -> None:
        """Precompute roots of unity and twiddle-step tables for ``mod``.

        Raises:
            ValueError: if ``mod >= 2**31``.
        """
        if isinstance(mod, int):
            self.mod = lambda: mod
        else:
            self.mod = mod
        md = self.mod()
        if md >= (1 << 31):
            raise ValueError("given mod is too large. use ArbitraryModConvolution")
        g = primitive_root(md)
        # md - 1 == odd * 2**rank2, so 2**rank2-th roots of unity exist.
        self.rank2 = rank2 = tzcount(md - 1)
        # root[i] is a 2**i-th root of unity (root[i + 1] squared); iroot[i] is its inverse.
        self.root = root = [0] * (rank2 + 1)
        self.iroot = iroot = [0] * (rank2 + 1)
        # Per-block twiddle step factors for the radix-2 (rate2) and radix-4 (rate3) passes.
        self.rate2 = rate2 = [0] * max(0, rank2 - 1)
        self.irate2 = irate2 = [0] * max(0, rank2 - 1)
        self.rate3 = rate3 = [0] * max(0, rank2 - 2)
        self.irate3 = irate3 = [0] * max(0, rank2 - 2)
        root[rank2] = pow(g, (md - 1) >> rank2, md)
        iroot[rank2] = pow(root[rank2], md - 2, md)
        for i in range(rank2 - 1, -1, -1):
            root[i] = root[i + 1] * root[i + 1] % md
            iroot[i] = iroot[i + 1] * iroot[i + 1] % md
        prod = 1
        iprod = 1
        for i in range(rank2 - 1):
            rate2[i] = root[i + 2] * prod % md
            irate2[i] = iroot[i + 2] * iprod % md
            prod = prod * iroot[i + 2] % md
            iprod = iprod * root[i + 2] % md
        prod = 1
        iprod = 1
        for i in range(rank2 - 2):
            rate3[i] = root[i + 3] * prod % md
            irate3[i] = iroot[i + 3] * iprod % md
            prod = prod * iroot[i + 3] % md
            iprod = iprod * root[i + 3] % md
        # 4th root of unity used by the radix-4 passes. It only exists when
        # rank2 >= 2; indexing root[2] unconditionally used to crash the
        # constructor for moduli such as 10**9 + 7. With rank2 < 2 no NTT
        # long enough to need it can pass the length check in convolution().
        if rank2 >= 2:
            self.imag = root[2]
            self.iimag = iroot[2]
        else:
            self.imag = None
            self.iimag = None

    def butterfly(self, a: List[int]) -> None:
        """Forward NTT of ``a`` in place.

        ``len(a)`` is expected to be a power of two no larger than
        ``2 ** self.rank2`` (``convolution`` pads to one). No bit-reversal
        pass is performed; the output is left in the order that
        ``butterfly_inv`` consumes.
        """
        md = self.mod()
        imag = self.imag
        rate2 = self.rate2
        rate3 = self.rate3
        h = (len(a) - 1).bit_length()
        len_ = 0
        while len_ < h:
            if h - len_ == 1:
                # Exactly one level left: radix-2 pass.
                p = 1 << (h - len_ - 1)
                rot = 1
                for s in range(1 << len_):
                    offset = s << (h - len_)
                    for i in range(p):
                        l = a[i + offset]
                        r = a[i + offset + p] * rot % md
                        a[i + offset] = (l + r) % md
                        a[i + offset + p] = (l - r) % md
                    if s + 1 != 1 << len_:
                        # ~s & (s + 1) isolates the lowest zero bit of s.
                        rot = rot * rate2[(~s & -~s).bit_length() - 1] % md
                len_ += 1
            else:
                # Two levels at once: radix-4 pass.
                p = 1 << (h - len_ - 2)
                rot = 1
                for s in range(1 << len_):
                    rot2 = rot * rot % md
                    rot3 = rot2 * rot % md
                    offset = s << (h - len_)
                    for i in range(p):
                        a0 = a[i + offset]
                        a1 = a[i + offset + p] * rot
                        a2 = a[i + offset + p * 2] * rot2
                        a3 = a[i + offset + p * 3] * rot3
                        a1na3imag = (a1 - a3) % md * imag
                        a[i + offset] = (a0 + a2 + a1 + a3) % md
                        a[i + offset + p] = (a0 + a2 - a1 - a3) % md
                        a[i + offset + p * 2] = (a0 - a2 + a1na3imag) % md
                        a[i + offset + p * 3] = (a0 - a2 - a1na3imag) % md
                    if s + 1 != 1 << len_:
                        rot = rot * rate3[(~s & -~s).bit_length() - 1] % md
                len_ += 2

    def butterfly_inv(self, a: List[int]) -> None:
        """Inverse of ``butterfly`` in place, without the final 1/len(a) scaling.

        Expects the order produced by ``butterfly``. ``convolution`` applies
        the 1/len(a) factor afterwards.
        """
        md = self.mod()
        iimag = self.iimag
        irate2 = self.irate2
        irate3 = self.irate3
        h = (len(a) - 1).bit_length()
        len_ = h
        while len_:
            if len_ == 1:
                # Exactly one level left: radix-2 pass.
                p = 1 << (h - len_)
                irot = 1
                for s in range(1 << (len_ - 1)):
                    offset = s << (h - len_ + 1)
                    for i in range(p):
                        l = a[i + offset]
                        r = a[i + offset + p]
                        a[i + offset] = (l + r) % md
                        a[i + offset + p] = (l - r) * irot % md
                    if s + 1 != (1 << (len_ - 1)):
                        irot = irot * irate2[(~s & -~s).bit_length() - 1] % md
                len_ -= 1
            else:
                # Two levels at once: radix-4 pass.
                p = 1 << (h - len_)
                irot = 1
                for s in range(1 << (len_ - 2)):
                    irot2 = irot * irot % md
                    irot3 = irot2 * irot % md
                    offset = s << (h - len_ + 2)
                    for i in range(p):
                        a0 = a[i + offset]
                        a1 = a[i + offset + p]
                        a2 = a[i + offset + p * 2]
                        a3 = a[i + offset + p * 3]
                        a2na3iimag = (a2 - a3) * iimag % md
                        a[i + offset] = (a0 + a1 + a2 + a3) % md
                        a[i + offset + p] = (a0 - a1 + a2na3iimag) * irot % md
                        a[i + offset + p * 2] = (a0 + a1 - a2 - a3) * irot2 % md
                        a[i + offset + p * 3] = (a0 - a1 - a2na3iimag) * irot3 % md
                    if s + 1 != (1 << (len_ - 2)):
                        irot = irot * irate3[(~s & -~s).bit_length() - 1] % md
                len_ -= 2

    def convolution(self, a: List[int], b: List[int]) -> List[int]:
        """Return the convolution of ``a`` and ``b`` modulo ``self.mod()``.

        The inputs are not modified. Empty input gives ``[]``. When the
        shorter input has at most 100 elements the schoolbook product is
        used; otherwise both are zero-padded to a power of two and multiplied
        pointwise in the NTT domain.

        Raises:
            ValueError: if the NTT path is needed but ``len(a) + len(b) - 1``
                exceeds ``2 ** self.rank2``. The schoolbook path has no
                length limit, so the check is applied only to the NTT path.
        """
        n, m = len(a), len(b)
        if not n or not m:
            return []
        md = self.mod()
        if min(n, m) <= 100:
            if n < m:
                n, m = m, n
                a, b = b, a
            res = [0] * (n + m - 1)
            for i in range(n):
                ai = a[i]
                for j in range(m):
                    res[i + j] = (res[i + j] + ai * b[j]) % md
            return res
        if n + m - 1 > (1 << self.rank2):
            raise ValueError('rank2 of given mod is too small. use ArbitraryModConvolution')
        size = n + m - 1
        z = 1 << (size - 1).bit_length()
        fa = list(a)
        fa.extend([0] * (z - n))
        fb = list(b)
        fb.extend([0] * (z - m))
        self.butterfly(fa)
        self.butterfly(fb)
        for i in range(z):
            fa[i] = fa[i] * fb[i] % md
        self.butterfly_inv(fa)
        # Undo the length factor introduced by forward + inverse transforms.
        iz = pow(z, md - 2, md)
        return [v * iz % md for v in fa[:size]]

class ArbitraryModConvolution():
    """Convolution modulo an arbitrary positive ``mod`` via several NTT primes + CRT.

    The inputs are convolved modulo each prime in ``self.mods``; the
    per-prime results are merged with the Chinese remainder theorem and
    reduced modulo ``mod``.

    The primes are chosen so that their product exceeds
    ``max_size * mod * mod``. That bounds every coefficient of the true
    convolution of arrays with entries in ``[0, mod)`` as long as
    ``min(len(a), len(b)) <= max_size``. The size bound itself is not
    re-checked at call time.

    ``mod`` may be an int or a zero-argument callable returning the modulus.
    ``fmt_mods`` optionally replaces the built-in prime list with
    zero-argument callables returning NTT-friendly primes.
    """

    def __init__(self, mod: Union[Callable[[], int], int], max_size: int = 2**20, fmt_mods: Optional[List[Callable[[], int]]] = None) -> None:
        """Pick the NTT primes and build one ``Convolution`` per prime.

        Raises:
            ValueError: if the built-in primes cannot exceed the bound, or the
                product of ``fmt_mods`` is below it.
            TypeError: if an entry of ``fmt_mods`` is not a callable returning int.
        """
        if isinstance(mod, int):
            self.mod = lambda: mod
        else:
            self.mod = mod
        md = self.mod()
        # The product of the NTT primes must exceed this for CRT to be exact.
        bound = max_size * md * md
        if fmt_mods is None:
            candidates = [998244353,  # 119 * 2^23 + 1
                          943718401,  # 225 * 2^22 + 1
                          918552577,  # 219 * 2^22 + 1
                          924844033,  # 441 * 2^21 + 1
                          985661441,  # 235 * 2^22 + 1
                          ]
            self.mods = []
            mul = 1
            for p in candidates:
                mul *= p
                self.mods.append(p)
                if mul > bound:
                    break
            else:
                raise ValueError("given mod is too large")
            self.convs = [Convolution(p) for p in self.mods]
        else:
            self.mods = []
            mul = 1
            for MOD in fmt_mods:
                if not callable(MOD) or not isinstance(MOD(), int):
                    raise TypeError("fmt_mods must be a list of functions that return int")
                mul *= MOD()
                self.mods.append(MOD())
            if mul < bound:
                raise ValueError("the product of fmt_mods is too small. add another mod to fmt_mods")
            self.convs = [Convolution(MOD) for MOD in fmt_mods]
        # The shortest supported NTT length across all primes limits the output size.
        self.minrank2 = min(conv.rank2 for conv in self.convs)

    def convolution(self, a: List[int], b: List[int]) -> List[int]:
        """Return the convolution of ``a`` and ``b`` modulo ``self.mod()``.

        Empty input gives ``[]``. If the shorter input has at most 100
        elements the schoolbook product is used. Otherwise each input is
        reduced modulo ``self.mod()``, convolved modulo every prime in
        ``self.mods``, and recombined with CRT.

        Raises:
            ValueError: if the NTT path is needed and ``len(a) + len(b) - 1``
                exceeds ``2 ** self.minrank2``.
        """
        n = len(a)
        m = len(b)
        if not n or not m:
            return []
        md = self.mod()
        if min(n, m) <= 100:
            if n < m:
                n, m = m, n
                a, b = b, a
            res = [0] * (n + m - 1)
            for i in range(n):
                ai = a[i]
                for j in range(m):
                    res[i + j] = (res[i + j] + ai * b[j]) % md
            return res
        if n + m - 1 > (1 << self.minrank2):
            raise ValueError('the lengths of given arrays is too large or the minimum rank2 for fmt_mods is too small. use difference mods')
        # Reduce into [0, md) first. CRT only reconstructs values in
        # [0, prod(self.mods)), so negative or oversized inputs would
        # otherwise come back with the wrong residue.
        a = [v % md for v in a]
        b = [v % md for v in b]
        cs = [conv.convolution([v % p for v in a], [v % p for v in b])
              for conv, p in zip(self.convs, self.mods)]
        mods = list(self.mods)
        return [crt(residues, mods)[0] % md for residues in zip(*cs)]
    
N = int(input())

# arr[k] starts as the coefficient of x**k in 1 + x + ... + x**N.
arr = [1 for _ in range(N + 1)]

# NOTE(review): the modulus looks chosen just above the largest possible
# answer so that the printed value is the exact count — confirm.
conv = ArbitraryModConvolution(2721355068691 + 1)

# Squaring three times raises the polynomial to the 8th power. arr[k] is then
# (modulo the modulus) the number of 8-tuples in [0, N] summing to k.
for i in (0, 1, 2):
    arr = conv.convolution(arr, arr)

print(arr[N * 6])
0