結果
問題 | No.287 場合の数 |
ユーザー | toyuzuko |
提出日時 | 2023-05-05 11:33:35 |
言語 | PyPy3 (7.3.15) |
結果 | AC |
実行時間 | 162 ms / 5,000 ms |
コード長 | 11,422 bytes |
コンパイル時間 | 225 ms |
コンパイル使用メモリ | 81,932 KB |
実行使用メモリ | 79,944 KB |
最終ジャッジ日時 | 2024-05-02 08:35:35 |
合計ジャッジ時間 | 4,772 ms |
ジャッジサーバーID (参考情報) | judge1 / judge4 |
(要ログイン)
テストケース
テストケース表示入力 | 結果 | 実行時間 実行使用メモリ |
---|---|---|
testcase_00 | AC | 158 ms 79,000 KB |
testcase_01 | AC | 70 ms 69,292 KB |
testcase_02 | AC | 74 ms 72,408 KB |
testcase_03 | AC | 162 ms 79,944 KB |
testcase_04 | AC | 150 ms 79,292 KB |
testcase_05 | AC | 148 ms 78,984 KB |
testcase_06 | AC | 75 ms 72,152 KB |
testcase_07 | AC | 148 ms 79,088 KB |
testcase_08 | AC | 150 ms 78,996 KB |
testcase_09 | AC | 154 ms 78,988 KB |
testcase_10 | AC | 155 ms 79,276 KB |
testcase_11 | AC | 133 ms 79,136 KB |
testcase_12 | AC | 152 ms 79,528 KB |
testcase_13 | AC | 131 ms 79,040 KB |
testcase_14 | AC | 145 ms 78,944 KB |
testcase_15 | AC | 149 ms 79,084 KB |
testcase_16 | AC | 148 ms 79,104 KB |
testcase_17 | AC | 153 ms 78,988 KB |
testcase_18 | AC | 146 ms 78,964 KB |
testcase_19 | AC | 145 ms 79,008 KB |
testcase_20 | AC | 143 ms 79,056 KB |
testcase_21 | AC | 151 ms 79,088 KB |
testcase_22 | AC | 153 ms 79,172 KB |
testcase_23 | AC | 149 ms 79,112 KB |
testcase_24 | AC | 143 ms 79,120 KB |
ソースコード
# Number-theoretic-transform based convolution helpers plus a small driver
# that reads N from stdin and prints one coefficient of
# (1 + x + ... + x^N)^8.
from typing import Callable, List, Optional, Sequence, Tuple, Union


def primitive_root(m: int) -> int:
    """Return the smallest g >= 2 of full multiplicative order modulo ``m``.

    A handful of common NTT primes (and ``m == 2``) are answered from a
    hard-coded table.  Otherwise the distinct prime factors of ``m - 1`` are
    found by trial division and candidates ``g = 2, 3, ...`` are tested.

    NOTE(review): the test is only meaningful when ``m`` is prime; the
    function does not verify this.
    """
    if m == 2:
        return 1
    if m == 167772161:
        return 3
    if m == 469762049:
        return 3
    if m == 754974721:
        return 11
    if m == 998244353:
        return 3

    # Distinct prime factors of m - 1 (2 is always one of them for odd m).
    divs = [0] * 20
    divs[0] = 2
    cnt = 1
    x = (m - 1) // 2
    while x % 2 == 0:
        x //= 2

    i = 3
    while i * i <= x:
        if x % i == 0:
            divs[cnt] = i
            cnt += 1
            while x % i == 0:
                x //= i
        i += 2

    if x > 1:
        divs[cnt] = x
        cnt += 1

    # g is accepted once g^((m-1)/p) != 1 for every prime factor p of m - 1.
    g = 2
    while True:
        for i in range(cnt):
            if pow(g, (m - 1) // divs[i], m) == 1:
                break
        else:
            return g
        g += 1


def inv_gcd(a: int, b: int) -> Tuple[int, int]:
    """Return ``(g, x)`` with ``g = gcd(a, b)`` and ``a * x == g (mod b)``.

    ``x`` is normalised into ``[0, b // g)``.  When ``a`` is a multiple of
    ``b`` the result is ``(b, 0)``.
    """
    a %= b
    if a == 0:
        return b, 0

    # Extended Euclid keeping only the coefficient of ``a``.
    s = b
    t = a
    m0 = 0
    m1 = 1

    while t:
        u = s // t
        s -= t * u
        m0 -= m1 * u
        s, t = t, s
        m0, m1 = m1, m0

    if m0 < 0:
        m0 += b // s

    return s, m0


def crt(r: Sequence[int], m: Sequence[int]) -> Tuple[int, int]:
    """Solve the simultaneous congruences ``x == r[i] (mod m[i])``.

    Returns ``(x, lcm)`` with ``0 <= x < lcm`` when a solution exists and
    ``(0, 0)`` when the system is inconsistent.  An empty system yields
    ``(0, 1)``.  The lengths of ``r`` and ``m`` must match and every modulus
    must be at least 1 (both enforced with ``assert``).
    """
    assert len(r) == len(m)

    n = len(r)

    r0 = 0
    m0 = 1
    for i in range(n):
        assert 1 <= m[i]
        r1 = r[i] % m[i]
        m1 = m[i]
        # Keep (r0, m0) as the congruence with the larger modulus.
        if m0 < m1:
            r0, r1 = r1, r0
            m0, m1 = m1, m0
        if m0 % m1 == 0:
            if r0 % m1 != r1:
                return 0, 0
            continue

        g, im = inv_gcd(m0, m1)
        u1 = m1 // g
        if (r1 - r0) % g:
            return 0, 0

        x = (r1 - r0) // g * im % u1
        r0 += x * m0
        m0 *= u1
        if r0 < 0:
            r0 += m0

    return r0, m0


def popcount(x: int) -> int:
    """Count the set bits among the low 32 bits of ``x``."""
    x = ((x >> 1) & 0x55555555) + (x & 0x55555555)
    x = ((x >> 2) & 0x33333333) + (x & 0x33333333)
    x = ((x >> 4) & 0x0f0f0f0f) + (x & 0x0f0f0f0f)
    x = ((x >> 8) & 0x00ff00ff) + (x & 0x00ff00ff)
    x = ((x >> 16) & 0x0000ffff) + (x & 0x0000ffff)
    return x


def tzcount(x: int) -> int:
    """Return the number of trailing zero bits of ``x`` (32-bit view).

    ``~x & (x - 1)`` has ones exactly at the trailing-zero positions of
    ``x``; for ``x == 0`` the result is 32.
    """
    return popcount(~x & (x - 1))


class Convolution():
    """Convolution modulo a single NTT-friendly modulus below 2**31.

    Attributes set by ``__init__``:
      mod     -- zero-argument callable returning the modulus
      rank2   -- number of trailing zero bits of ``mod() - 1``
      root / iroot   -- ``root[i]`` is a 2**i-th root of unity, ``iroot[i]``
                        its inverse
      rate2 / irate2, rate3 / irate3 -- twiddle-update tables used by the
                        radix-2 and radix-4 butterfly passes
      imag / iimag   -- ``root[2]`` and ``iroot[2]`` (0 when rank2 < 2)
    """

    def __init__(self, mod: Union[Callable[[], int], int]) -> None:
        """Precompute the root tables for ``mod`` (an int or a callable).

        Raises:
            ValueError: if the modulus is 2**31 or larger.
        """
        if isinstance(mod, int):
            self.mod = lambda: mod
        else:
            self.mod = mod

        md = self.mod()
        if md >= (1 << 31):
            raise ValueError("given mod is too large. use ArbitraryModConvolution")

        g = primitive_root(md)

        self.rank2 = rank2 = tzcount(md - 1)
        self.root = root = [0] * (rank2 + 1)
        self.iroot = iroot = [0] * (rank2 + 1)
        self.rate2 = rate2 = [0] * max(0, rank2 - 1)
        self.irate2 = irate2 = [0] * max(0, rank2 - 1)
        self.rate3 = rate3 = [0] * max(0, rank2 - 2)
        self.irate3 = irate3 = [0] * max(0, rank2 - 2)

        root[rank2] = pow(g, (md - 1) >> rank2, md)
        iroot[rank2] = pow(root[rank2], md - 2, md)
        for i in range(rank2)[::-1]:
            root[i] = root[i + 1] * root[i + 1] % md
            iroot[i] = iroot[i + 1] * iroot[i + 1] % md

        prod = 1
        iprod = 1
        for i in range(rank2 - 1):
            rate2[i] = root[i + 2] * prod % md
            irate2[i] = iroot[i + 2] * iprod % md
            prod *= iroot[i + 2]
            prod %= md
            iprod *= root[i + 2]
            iprod %= md

        prod = 1
        iprod = 1
        for i in range(rank2 - 2):
            rate3[i] = root[i + 3] * prod % md
            irate3[i] = iroot[i + 3] * iprod % md
            prod *= iroot[i + 3]
            prod %= md
            iprod *= root[i + 3]
            iprod %= md

        # root[2] only exists when rank2 >= 2.  Moduli with a smaller rank
        # can still serve the naive path of convolution(), which never
        # touches imag / iimag, so do not crash here.
        self.imag = root[2] if rank2 >= 2 else 0
        self.iimag = iroot[2] if rank2 >= 2 else 0

    def butterfly(self, a: List[int]) -> None:
        """Transform ``a`` in place (forward pass).

        Radix-4 passes are used while at least two levels remain and a final
        radix-2 pass handles an odd leftover level.
        NOTE(review): ``len(a)`` is presumably a power of two, as arranged by
        ``convolution``; this is not checked here.
        """
        mod = self.mod()
        rate2 = self.rate2
        rate3 = self.rate3
        imag = self.imag

        n = len(a)
        h = (n - 1).bit_length()

        len_ = 0
        while len_ < h:
            if h - len_ == 1:
                p = 1 << (h - len_ - 1)
                rot = 1
                for s in range(1 << len_):
                    offset = s << (h - len_)
                    for i in range(p):
                        l = a[i + offset]
                        r = a[i + offset + p] * rot % mod
                        a[i + offset] = (l + r) % mod
                        a[i + offset + p] = (l - r) % mod
                    if s + 1 != 1 << len_:
                        # index of the lowest zero bit of s
                        rot *= rate2[(~s & -~s).bit_length() - 1]
                        rot %= mod
                len_ += 1
            else:
                p = 1 << (h - len_ - 2)
                rot = 1
                for s in range(1 << len_):
                    rot2 = rot * rot % mod
                    rot3 = rot2 * rot % mod
                    offset = s << (h - len_)
                    for i in range(p):
                        a0 = a[i + offset]
                        a1 = a[i + offset + p] * rot
                        a2 = a[i + offset + p * 2] * rot2
                        a3 = a[i + offset + p * 3] * rot3
                        a1na3imag = (a1 - a3) % mod * imag
                        a[i + offset] = (a0 + a2 + a1 + a3) % mod
                        a[i + offset + p] = (a0 + a2 - a1 - a3) % mod
                        a[i + offset + p * 2] = (a0 - a2 + a1na3imag) % mod
                        a[i + offset + p * 3] = (a0 - a2 - a1na3imag) % mod
                    if s + 1 != 1 << len_:
                        rot *= rate3[(~s & -~s).bit_length() - 1]
                        rot %= mod
                len_ += 2

    def butterfly_inv(self, a: List[int]) -> None:
        """Apply the inverse pass to ``a`` in place.

        The result is not divided by ``len(a)``; ``convolution`` performs
        that scaling itself.
        """
        mod = self.mod()
        irate2 = self.irate2
        irate3 = self.irate3
        iimag = self.iimag

        n = len(a)
        h = (n - 1).bit_length()

        len_ = h
        while len_:
            if len_ == 1:
                p = 1 << (h - len_)
                irot = 1
                for s in range(1 << (len_ - 1)):
                    offset = s << (h - len_ + 1)
                    for i in range(p):
                        l = a[i + offset]
                        r = a[i + offset + p]
                        a[i + offset] = (l + r) % mod
                        a[i + offset + p] = (l - r) * irot % mod
                    if s + 1 != (1 << (len_ - 1)):
                        irot *= irate2[(~s & -~s).bit_length() - 1]
                        irot %= mod
                len_ -= 1
            else:
                p = 1 << (h - len_)
                irot = 1
                for s in range(1 << (len_ - 2)):
                    irot2 = irot * irot % mod
                    irot3 = irot2 * irot % mod
                    offset = s << (h - len_ + 2)
                    for i in range(p):
                        a0 = a[i + offset]
                        a1 = a[i + offset + p]
                        a2 = a[i + offset + p * 2]
                        a3 = a[i + offset + p * 3]
                        a2na3iimag = (a2 - a3) * iimag % mod
                        a[i + offset] = (a0 + a1 + a2 + a3) % mod
                        a[i + offset + p] = (a0 - a1 + a2na3iimag) * irot % mod
                        a[i + offset + p * 2] = (a0 + a1 - a2 - a3) * irot2 % mod
                        a[i + offset + p * 3] = (a0 - a1 - a2na3iimag) * irot3 % mod
                    if s + 1 != (1 << (len_ - 2)):
                        irot *= irate3[(~s & -~s).bit_length() - 1]
                        irot %= mod
                len_ -= 2

    def convolution(self, a: List[int], b: List[int]) -> List[int]:
        """Return the convolution of ``a`` and ``b`` modulo ``self.mod()``.

        The inputs are copied, not modified.  When the shorter input has at
        most 100 elements a direct double loop is used; otherwise the
        butterfly transforms are used.

        Raises:
            ValueError: if the result length exceeds ``2 ** rank2``.
        """
        mod = self.mod()
        a, b = a.copy(), b.copy()
        n, m = len(a), len(b)

        if n + m - 1 > (1 << self.rank2):
            raise ValueError('rank2 of given mod is too small. use ArbitraryModConvolution')

        if not n or not m:
            return []

        if min(n, m) <= 100:
            if n < m:
                n, m = m, n
                a, b = b, a
            res = [0] * (n + m - 1)
            for i in range(n):
                for j in range(m):
                    res[i + j] += a[i] * b[j]
                    res[i + j] %= mod
            return res

        z = 1 << (n + m - 2).bit_length()
        a += [0] * (z - n)
        b += [0] * (z - m)

        self.butterfly(a)
        self.butterfly(b)

        for i in range(z):
            a[i] *= b[i]
            a[i] %= mod

        self.butterfly_inv(a)
        a = a[:n + m - 1]

        # Undo the factor z left over by the unnormalised inverse pass.
        iz = pow(z, mod - 2, mod)
        for i in range(n + m - 1):
            a[i] *= iz
            a[i] %= mod

        return a


class ArbitraryModConvolution():
    """Convolution modulo an arbitrary modulus via several NTT primes + CRT.

    Attributes set by ``__init__``:
      mod       -- zero-argument callable returning the target modulus
      mods      -- list of the NTT moduli (ints) in use
      convs     -- one ``Convolution`` per entry of ``mods``
      minrank2  -- smallest ``rank2`` among ``convs``
    """

    def __init__(self, mod: Union[Callable[[], int], int], max_size: int = 2**20,
                 fmt_mods: Optional[List[Callable[[], int]]] = None) -> None:
        """Choose NTT moduli whose product exceeds ``max_size * mod**2``.

        Args:
            mod: target modulus, either an int or a callable returning one.
            max_size: factor used in the size bound above.
            fmt_mods: optional list of callables returning the NTT moduli to
                use; when omitted, built-in primes are taken in order until
                the bound is exceeded.

        Raises:
            ValueError: if the built-in primes cannot exceed the bound, or
                the product of ``fmt_mods`` is below it.
            TypeError: if an entry of ``fmt_mods`` is not a callable
                returning an int.
        """
        if isinstance(mod, int):
            self.mod = lambda: mod
        else:
            self.mod = mod

        md = self.mod()
        bound = max_size * md * md

        if fmt_mods is None:
            default_mods = (
                998244353,  # 119 * 2^23 + 1
                943718401,  # 225 * 2^22 + 1
                918552577,  # 219 * 2^22 + 1
                924844033,  # 441 * 2^21 + 1
                985661441,  # 235 * 2^22 + 1
            )
            self.mods = []
            mul = 1
            for value in default_mods:
                mul *= value
                self.mods.append(value)
                if mul > bound:
                    break
            else:
                raise ValueError("given mod is too large")
            self.convs = [Convolution(value) for value in self.mods]
            self.minrank2 = min([conv.rank2 for conv in self.convs])
        else:
            self.mods = []
            mul = 1
            for mod_fn in fmt_mods:
                if not callable(mod_fn) or not isinstance(mod_fn(), int):
                    raise TypeError("fmt_mods must be a list of functions that return int")
                value = mod_fn()
                mul *= value
                self.mods.append(value)
            if mul < bound:
                raise ValueError("the product of fmt_mods is too small. add another mod to fmt_mods")
            self.convs = [Convolution(mod_fn) for mod_fn in fmt_mods]
            self.minrank2 = min([conv.rank2 for conv in self.convs])

    def convolution(self, a: List[int], b: List[int]) -> List[int]:
        """Return the convolution of ``a`` and ``b`` modulo ``self.mod()``.

        When the shorter input has at most 100 elements a direct double loop
        is used.  Otherwise each NTT modulus computes the convolution and the
        residues are combined with ``crt``.

        Raises:
            ValueError: if the result length exceeds ``2 ** minrank2``.
        """
        md = self.mod()
        n = len(a)
        m = len(b)

        if n + m - 1 > (1 << self.minrank2):
            raise ValueError('the lengths of given arrays is too large or the minimum rank2 for fmt_mods is too small. use difference mods')

        if not n or not m:
            return []

        if min(n, m) <= 100:
            if n < m:
                n, m = m, n
                a, b = b, a
            res = [0] * (n + m - 1)
            for i in range(n):
                for j in range(m):
                    res[i + j] += a[i] * b[j]
                    res[i + j] %= md
            return res

        # The CRT step recovers the exact integer coefficient only if it lies
        # in [0, product of self.mods).  Reduce the inputs modulo the target
        # modulus first so negative or oversized entries cannot break that
        # bound (the naive path above already handles them correctly).
        a = [v % md for v in a]
        b = [v % md for v in b]

        cs = [conv.convolution([v % p for v in a], [v % p for v in b])
              for conv, p in zip(self.convs, self.mods)]

        return [crt(v, self.mods)[0] % md for v in zip(*cs)]


def solve(N: int) -> int:
    """Return the x^(6N) coefficient of (1 + x + ... + x^N)^8.

    The all-ones polynomial of degree ``N`` is squared three times and the
    coefficient is read modulo 2721355068692.
    NOTE(review): that modulus is presumably larger than any possible answer
    under the problem's constraints -- confirm against the statement.
    """
    arr = [1] * (N + 1)
    conv = ArbitraryModConvolution(2721355068691 + 1)

    for _ in range(3):
        arr = conv.convolution(arr, arr)

    return arr[6 * N]


def main() -> None:
    """Read N from standard input and print ``solve(N)``."""
    N = int(input())
    print(solve(N))


if __name__ == "__main__":
    main()