結果
問題 | No.287 場合の数 |
ユーザー |
![]() |
提出日時 | 2022-12-20 18:31:58 |
言語 | PyPy3 (7.3.15) |
結果 | AC |
実行時間 | 389 ms / 5,000 ms |
コード長 | 8,365 bytes |
コンパイル時間 | 207 ms |
コンパイル使用メモリ | 82,176 KB |
実行使用メモリ | 80,776 KB |
最終ジャッジ日時 | 2024-11-18 01:58:12 |
合計ジャッジ時間 | 8,941 ms |
ジャッジサーバーID (参考情報) | judge2 / judge3 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | AC * 3 |
other | AC * 22 |
ソースコード
# Counts 8-tuples (a_1, ..., a_8) with 0 <= a_i <= n and a_1 + ... + a_8 = 6n,
# i.e. the coefficient of t^(6n) in (1 + t + ... + t^n)^8.  Polynomials are
# multiplied exactly with NTTs over five NTT-friendly primes whose results are
# merged with the Chinese remainder theorem (see convolution_64bit).

MOD1 = 998244353
MOD2 = 985661441
MOD3 = 943718401
MOD4 = 935329793
MOD5 = 918552577
# Zero-argument callables returning each modulus; the NTT routines below take
# the modulus in this form.
mod1 = lambda: MOD1
mod2 = lambda: MOD2
mod3 = lambda: MOD3
mod4 = lambda: MOD4
mod5 = lambda: MOD5


def primitive_root(m):
    """Return a primitive root modulo the prime ``m``.

    A few well-known NTT primes are answered from a table.  Otherwise the
    prime factors q of ``m - 1`` are collected, and candidates g = 2, 3, ...
    are tried until g^((m-1)/q) != 1 (mod m) for every q.
    """
    if m == 2:
        return 1
    if m == 167772161:
        return 3
    if m == 469762049:
        return 3
    if m == 754974721:
        return 11
    if m == 998244353:
        return 3
    divs = [0] * 20
    divs[0] = 2
    cnt = 1
    x = (m - 1) // 2
    while x % 2 == 0:
        x //= 2
    i = 3
    while i * i <= x:
        if x % i == 0:
            divs[cnt] = i
            cnt += 1
            while x % i == 0:
                x //= i
        i += 2
    if x > 1:
        divs[cnt] = x
        cnt += 1
    g = 2
    while True:
        for i in range(cnt):
            if pow(g, (m - 1) // divs[i], m) == 1:
                break
        else:
            return g
        g += 1


def popcount(x):
    """Return the number of set bits among the low 32 bits of ``x`` (SWAR)."""
    x = ((x >> 1) & 0x55555555) + (x & 0x55555555)
    x = ((x >> 2) & 0x33333333) + (x & 0x33333333)
    x = ((x >> 4) & 0x0F0F0F0F) + (x & 0x0F0F0F0F)
    x = ((x >> 8) & 0x00FF00FF) + (x & 0x00FF00FF)
    x = ((x >> 16) & 0x0000FFFF) + (x & 0x0000FFFF)
    return x


def tzcount(x):
    """Return the number of trailing zero bits of ``x`` (for 0 < x < 2**32)."""
    # ~x & (x - 1) keeps exactly the trailing-zero positions of x as ones.
    return popcount(~x & (x - 1))


def build_ntt(mod):
    """Precompute NTT twiddle tables for the prime ``mod()``.

    Returns ``(root, iroot, rate2, irate2, rate3, irate3)``.  ``root[i]`` has
    multiplicative order 2**i and ``iroot[i]`` is its inverse.  The rate
    tables hold the per-block twiddle increments used by the radix-2 and
    radix-4 passes of ``butterfly`` / ``butterfly_inv``.
    """
    md = mod()
    g = primitive_root(md)
    rank2 = tzcount(md - 1)
    root = [0] * (rank2 + 1)
    iroot = [0] * (rank2 + 1)
    rate2 = [0] * max(0, rank2 - 1)
    irate2 = [0] * max(0, rank2 - 1)
    rate3 = [0] * max(0, rank2 - 2)
    irate3 = [0] * max(0, rank2 - 2)
    root[rank2] = pow(g, (md - 1) >> rank2, md)
    iroot[rank2] = pow(root[rank2], md - 2, md)
    for i in reversed(range(rank2)):
        root[i] = root[i + 1] * root[i + 1] % md
        iroot[i] = iroot[i + 1] * iroot[i + 1] % md
    prod = 1
    iprod = 1
    for i in range(rank2 - 1):
        rate2[i] = root[i + 2] * prod % md
        irate2[i] = iroot[i + 2] * iprod % md
        prod = prod * iroot[i + 2] % md
        iprod = iprod * root[i + 2] % md
    prod = 1
    iprod = 1
    for i in range(rank2 - 2):
        rate3[i] = root[i + 3] * prod % md
        irate3[i] = iroot[i + 3] * iprod % md
        prod = prod * iroot[i + 3] % md
        iprod = iprod * root[i + 3] % md
    return root, iroot, rate2, irate2, rate3, irate3


# Memoized build_ntt results, keyed by the modulus value.  The tables are
# only read by the butterflies, so sharing them between calls is safe.
_NTT_TABLES = {}


def _ntt_tables(mod):
    """Return ``build_ntt(mod)``, computed once per modulus value."""
    key = mod()
    tables = _NTT_TABLES.get(key)
    if tables is None:
        tables = _NTT_TABLES[key] = build_ntt(mod)
    return tables


def butterfly(a, mod, rate2, irate2, rate3, irate3, imag, iimag):
    """Forward NTT of ``a`` in place, modulo ``mod()``.

    ``len(a)`` must be a power of two.  Radix-4 passes are used, plus one
    radix-2 pass when log2(len(a)) is odd.  The output is not in natural
    order; it is only meant for pointwise products followed by
    ``butterfly_inv``.  ``irate2``, ``irate3`` and ``iimag`` are unused here
    and are accepted only for signature symmetry with ``butterfly_inv``.
    """
    md = mod()
    h = (len(a) - 1).bit_length()
    len_ = 0
    while len_ < h:
        if h - len_ == 1:
            # Single radix-2 pass, needed when the remaining depth is odd.
            p = 1 << (h - len_ - 1)
            rot = 1
            for s in range(1 << len_):
                offset = s << (h - len_)
                for i in range(offset, offset + p):
                    u = a[i]
                    v = a[i + p] * rot % md
                    a[i] = (u + v) % md
                    a[i + p] = (u - v) % md
                if s + 1 != 1 << len_:
                    # ~s & (s + 1) isolates the lowest zero bit of s.
                    rot = rot * rate2[(~s & -~s).bit_length() - 1] % md
            len_ += 1
        else:
            p = 1 << (h - len_ - 2)
            rot = 1
            for s in range(1 << len_):
                rot2 = rot * rot % md
                rot3 = rot2 * rot % md
                offset = s << (h - len_)
                for i in range(offset, offset + p):
                    a0 = a[i]
                    a1 = a[i + p] * rot
                    a2 = a[i + 2 * p] * rot2
                    a3 = a[i + 3 * p] * rot3
                    a1na3imag = (a1 - a3) % md * imag
                    a[i] = (a0 + a2 + a1 + a3) % md
                    a[i + p] = (a0 + a2 - a1 - a3) % md
                    a[i + 2 * p] = (a0 - a2 + a1na3imag) % md
                    a[i + 3 * p] = (a0 - a2 - a1na3imag) % md
                if s + 1 != 1 << len_:
                    rot = rot * rate3[(~s & -~s).bit_length() - 1] % md
            len_ += 2


def butterfly_inv(a, mod, rate2, irate2, rate3, irate3, imag, iimag):
    """Inverse of ``butterfly``, in place, modulo ``mod()``.

    The result is scaled by ``len(a)``; the caller must multiply by its
    modular inverse.  ``rate2``, ``rate3`` and ``imag`` are unused here and
    are accepted only for signature symmetry with ``butterfly``.
    """
    md = mod()
    h = (len(a) - 1).bit_length()
    len_ = h
    while len_:
        if len_ == 1:
            p = 1 << (h - len_)
            irot = 1
            for s in range(1 << (len_ - 1)):
                offset = s << (h - len_ + 1)
                for i in range(offset, offset + p):
                    u = a[i]
                    v = a[i + p]
                    a[i] = (u + v) % md
                    a[i + p] = (u - v) * irot % md
                if s + 1 != 1 << (len_ - 1):
                    irot = irot * irate2[(~s & -~s).bit_length() - 1] % md
            len_ -= 1
        else:
            p = 1 << (h - len_)
            irot = 1
            for s in range(1 << (len_ - 2)):
                irot2 = irot * irot % md
                irot3 = irot2 * irot % md
                offset = s << (h - len_ + 2)
                for i in range(offset, offset + p):
                    a0 = a[i]
                    a1 = a[i + p]
                    a2 = a[i + 2 * p]
                    a3 = a[i + 3 * p]
                    a2na3iimag = (a2 - a3) * iimag % md
                    a[i] = (a0 + a1 + a2 + a3) % md
                    a[i + p] = (a0 - a1 + a2na3iimag) * irot % md
                    a[i + 2 * p] = (a0 + a1 - a2 - a3) * irot2 % md
                    a[i + 3 * p] = (a0 - a1 - a2na3iimag) * irot3 % md
                if s + 1 != 1 << (len_ - 2):
                    irot = irot * irate3[(~s & -~s).bit_length() - 1] % md
            len_ -= 2


def convolution(a, b, mod):
    """Return the linear convolution of ``a`` and ``b`` modulo ``mod()``.

    ``a`` and ``b`` are not modified.  Returns ``[]`` if either is empty.
    When the shorter input has at most 100 elements the schoolbook product
    is used.  Otherwise an NTT of size z = 2**k >= len(a) + len(b) - 1 is
    used; z must not exceed 2**tzcount(mod() - 1).  Every output value lies
    in [0, mod()).
    """
    n = len(a)
    m = len(b)
    if not n or not m:
        return []
    md = mod()
    if min(n, m) <= 100:
        if n < m:
            n, m = m, n
            a, b = b, a
        res = [0] * (n + m - 1)
        for i in range(n):
            ai = a[i]
            for j in range(m):
                res[i + j] = (res[i + j] + ai * b[j]) % md
        return res
    root, iroot, rate2, irate2, rate3, irate3 = _ntt_tables(mod)
    imag = root[2]
    iimag = iroot[2]
    z = 1 << (n + m - 2).bit_length()
    # Work on padded copies so the caller's lists are left untouched.
    fa = list(a) + [0] * (z - n)
    fb = list(b) + [0] * (z - m)
    butterfly(fa, mod, rate2, irate2, rate3, irate3, imag, iimag)
    butterfly(fb, mod, rate2, irate2, rate3, irate3, imag, iimag)
    for i in range(z):
        fa[i] = fa[i] * fb[i] % md
    butterfly_inv(fa, mod, rate2, irate2, rate3, irate3, imag, iimag)
    iz = pow(z, md - 2, md)  # undo the len(a) scaling of butterfly_inv
    return [v * iz % md for v in fa[:n + m - 1]]


def inv_gcd(a, b):
    """Return ``(g, x)`` with g = gcd(a, b), a*x = g (mod b), 0 <= x < b // g.

    Requires b >= 1.
    """
    a %= b
    if a == 0:
        return b, 0
    s = b
    t = a
    m0 = 0
    m1 = 1
    while t:
        u = s // t
        s -= t * u
        m0 -= m1 * u
        s, t = t, s
        m0, m1 = m1, m0
    if m0 < 0:
        m0 += b // s
    return s, m0


def gcd(x, y):
    """Return the greatest common divisor of ``x`` and ``y`` (Euclid)."""
    while y:
        x, y = y, x % y
    return x


def crt(r, m):
    """Solve x = r[i] (mod m[i]) for every i.

    Returns ``(x, lcm(m))`` with 0 <= x < lcm(m), or ``(0, 0)`` if the
    system has no solution.  Empty input gives ``(0, 1)``.
    """
    assert len(r) == len(m)
    r0 = 0
    m0 = 1
    for ri, mi in zip(r, m):
        assert 1 <= mi
        r1 = ri % mi
        m1 = mi
        if m0 < m1:
            r0, r1 = r1, r0
            m0, m1 = m1, m0
        if m0 % m1 == 0:
            if r0 % m1 != r1:
                return 0, 0
            continue
        g, im = inv_gcd(m0, m1)
        u1 = m1 // g
        if (r1 - r0) % g:
            return 0, 0
        x = (r1 - r0) // g * im % u1
        r0 += x * m0
        m0 *= u1
        if r0 < 0:
            r0 += m0
    return r0, m0


def convolution_64bit(a, b):
    """Return the convolution of integer lists ``a`` and ``b`` modulo 2**64.

    Each coefficient is first recovered exactly modulo M = MOD1*...*MOD5
    (about 2**149), by five NTT convolutions combined with CRT, and then
    reduced to its low 64 bits.  With nonnegative inputs this is exact while
    every true coefficient is below M.  If any input is negative, residues
    above M // 2 are treated as negative, so results are exact while
    |coefficient| < M / 2.  Returns ``[]`` if either input is empty.
    """
    n = len(a)
    m = len(b)
    if not n or not m:
        # Previously produced n + m - 1 zeros here instead of an empty list.
        return []
    mask = (1 << 64) - 1
    mods = (MOD1, MOD2, MOD3, MOD4, MOD5)
    mod_fns = (mod1, mod2, mod3, mod4, mod5)
    residues = [
        convolution([v % p for v in a], [v % p for v in b], fn)
        for p, fn in zip(mods, mod_fns)
    ]
    # M is odd, so a negative coefficient c recovered as M + c would not
    # reduce to c mod 2**64; map it back to c before masking.
    signed = any(v < 0 for v in a) or any(v < 0 for v in b)
    res = [0] * (n + m - 1)
    for i, v in enumerate(zip(*residues)):
        cr, cm = crt(v, mods)
        if signed and cr > cm // 2:
            cr -= cm
        res[i] = cr & mask
    return res


def solve(n):
    """Return the coefficient of t^(6n) in (1 + t + ... + t^n)^8."""
    # x is 1 + t + ... + t^n padded to degree 6n; higher terms are never needed.
    x = [1] * (n + 1) + [0] * (5 * n)
    ans = x[:]
    for _ in range(7):
        ans = convolution_64bit(ans, x)[:6 * n + 1]
    return ans[6 * n]


def main():
    """Read n from stdin and print the answer."""
    n = int(input())
    print(solve(n))


if __name__ == "__main__":
    main()