結果
問題 | No.3080 Colonies on Line |
ユーザー |
![]() |
提出日時 | 2025-03-28 21:36:16 |
言語 | PyPy3 (7.3.15) |
結果 | TLE |
実行時間 | - |
コード長 | 9,994 bytes |
コンパイル時間 | 333 ms |
コンパイル使用メモリ | 82,148 KB |
実行使用メモリ | 139,696 KB |
最終ジャッジ日時 | 2025-03-28 21:36:31 |
合計ジャッジ時間 | 10,316 ms |
ジャッジサーバーID (参考情報) |
judge2 / judge4 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | AC * 3 |
other | AC * 8 TLE * 1 -- * 26 |
ソースコード
import random

mod = 998244353

# Modular inverse / factorial tables, read only by ``binom``.  They used to be
# filled to 10**6 entries at start-up even though this script never calls
# ``binom``; ``_ensure_tables`` now grows them on demand (each entry computed
# at most once).
inv = []
fact = []
fact_inv = []


def mod_inv(a, mod=998244353):
    """Return the inverse of ``a`` modulo ``mod`` (extended Euclid).

    Returns 0 when ``mod == 1``.  If ``a`` is not coprime to ``mod`` the
    Euclidean steps eventually divide by zero (ZeroDivisionError).
    """
    if mod == 1:
        return 0
    a %= mod
    b, s, t = mod, 1, 0
    # Invariants: s * a0 ≡ a and t * a0 ≡ b (mod ``mod``) for the input a0.
    while True:
        if a == 1:
            return s
        t -= (b // a) * s
        b %= a
        if b == 1:
            return t + mod
        s -= (a // b) * t
        a %= b


def _ensure_tables(limit):
    """Extend ``inv``/``fact``/``fact_inv`` so that index ``limit`` is valid.

    Requires ``limit < mod`` (the recurrence needs every index invertible).
    """
    if len(fact) > limit:
        return
    if not fact:
        # Indices 0 and 1; inv[0] is a placeholder, as in the original table.
        inv.extend((1, 1))
        fact.extend((1, 1))
        fact_inv.extend((1, 1))
    for a in range(len(fact), limit + 1):
        # inv[a] = -(mod // a) * inv[mod % a]; mod % a < a is already known.
        inv.append((mod - inv[mod % a]) * (mod // a) % mod)
        fact.append(fact[-1] * a % mod)
        fact_inv.append(fact_inv[-1] * inv[-1] % mod)


def binom(n, r):
    """Return C(n, r) modulo ``mod``; 0 when ``r`` is outside ``[0, n]``."""
    if n < r or n < 0 or r < 0:
        return 0
    _ensure_tables(n)
    res = fact_inv[n - r] * fact_inv[r] % mod
    return res * fact[n] % mod


def Garner(Rem, MOD, mod):
    """Recombine residues ``x ≡ Rem[i] (mod MOD[i])`` and return ``x % mod``.

    ``MOD`` must be pairwise coprime; the true ``x`` is assumed to be smaller
    than the product of ``MOD``.  Neither input list is modified (the
    original appended a never-read sentinel to the caller's ``Rem``).
    """
    Mod = list(MOD) + [mod]
    k = len(Mod)
    coffs = [1] * k
    constants = [0] * k
    for i in range(k - 1):
        v = (Rem[i] - constants[i]) * mod_inv(coffs[i], Mod[i]) % Mod[i]
        for j in range(i + 1, k):
            constants[j] = (constants[j] + coffs[j] * v) % Mod[j]
            coffs[j] = coffs[j] * Mod[i] % Mod[j]
    return constants[-1]


def Tonelli_Shanks(a, p=998244353):
    """Return a square root of ``a`` modulo the prime ``p``, or -1 if none exists."""
    a %= p
    if a < 2:
        return a
    if pow(a, (p - 1) // 2, p) != 1:
        return -1  # Euler's criterion: a is a quadratic non-residue.
    if p % 4 == 3:
        return pow(a, (p + 1) // 4, p)
    # b must be a quadratic non-residue (3 is one modulo 998244353).
    b = 1
    if p == 998244353:
        b = 3
    else:
        while pow(b, (p - 1) // 2, p) == 1:
            b = random.randint(2, p - 1)
    # Write p - 1 = q * 2**Q with q odd.
    q = p - 1
    Q = 0
    while q % 2 == 0:
        Q, q = Q + 1, q // 2
    x = pow(a, (q + 1) // 2, p)
    b = pow(b, q, p)
    a_inv = pow(a, -1, p)  # loop-invariant; previously recomputed every pass
    shift = 2
    while x * x % p != a:
        error = a_inv * x * x % p
        if pow(error, 1 << (Q - shift), p) != 1:
            x = x * b % p
        b = b * b % p
        shift += 1
    return x


# mod -> (rank2, seed root, cache slot).  rank2 is the largest e with
# 2**e | mod - 1; the seed is squared down in prepared_fft, which requires it
# to be a primitive 2**rank2-th root of unity; the slot indexes the twiddle
# caches below.
_NTT_INFO = {
    998244353: (23, 31, 0),
    120586241: (20, 74066978, 1),
    167772161: (25, 17, 2),
    469762049: (26, 30, 3),
    754974721: (24, 362, 4),
    880803841: (23, 211, 5),
    924844033: (21, 44009197, 6),
    943718401: (22, 663003469, 7),
    1045430273: (20, 363, 8),
    1051721729: (20, 330, 9),
    1053818881: (20, 2789, 10),
}


def NTT_info(mod):
    """Return ``(rank2, root, slot)`` for a supported NTT prime, else ``(0, -1, -1)``."""
    return _NTT_INFO.get(mod, (0, -1, -1))


def prepared_fft(mod=998244353):
    """Build the root / inverse-root and radix-2 / radix-4 rate tables for ``ntt``."""
    rank2 = NTT_info(mod)[0]
    root, iroot = [0] * 30, [0] * 30
    rate2, irate2 = [0] * 30, [0] * 30
    rate3, irate3 = [0] * 30, [0] * 30
    root[rank2] = NTT_info(mod)[1]
    iroot[rank2] = pow(root[rank2], mod - 2, mod)
    # root[i] = root[i + 1] ** 2, so root[i] has order 2**i.
    for i in range(rank2 - 1, -1, -1):
        root[i] = root[i + 1] * root[i + 1] % mod
        iroot[i] = iroot[i + 1] * iroot[i + 1] % mod
    prod, iprod = 1, 1
    for i in range(rank2 - 1):
        rate2[i] = root[i + 2] * prod % mod
        irate2[i] = iroot[i + 2] * iprod % mod
        prod = prod * iroot[i + 2] % mod
        iprod = iprod * root[i + 2] % mod
    prod, iprod = 1, 1
    for i in range(rank2 - 2):
        rate3[i] = root[i + 3] * prod % mod
        irate3[i] = iroot[i + 3] * iprod % mod
        prod = prod * iroot[i + 3] % mod
        iprod = iprod * root[i + 3] % mod
    return root, iroot, rate2, irate2, rate3, irate3


# Per-modulus twiddle caches, indexed by NTT_info(mod)[2]; filled lazily by ntt.
root = [[] for _ in range(11)]
iroot = [[] for _ in range(11)]
rate2 = [[] for _ in range(11)]
irate2 = [[] for _ in range(11)]
rate3 = [[] for _ in range(11)]
irate3 = [[] for _ in range(11)]


def ntt(a, inverse=0, mod=998244353):
    """Transform ``a`` in place modulo the NTT prime ``mod``.

    ``len(a)`` must be a power of two.  ``inverse == 0`` runs the forward
    transform, anything else the inverse (including division by ``len(a)``).
    No bit-reversal permutation is applied, so the forward output is in a
    permuted order that the inverse transform expects; only pointwise
    combinations of transforms of the same length are meaningful.  ``mod``
    must be one of the primes known to ``NTT_info``.
    """
    idx = NTT_info(mod)[2]
    if len(root[idx]) == 0:
        root[idx], iroot[idx], rate2[idx], irate2[idx], rate3[idx], irate3[idx] = prepared_fft(mod)
    n = len(a)
    h = (n - 1).bit_length()
    assert (n == 1 << h)
    if inverse == 0:
        le = 0
        while le < h:
            if h - le == 1:
                # A single remaining level: radix-2 butterflies.
                p = 1 << (h - le - 1)
                rot = 1
                for s in range(1 << le):
                    offset = s << (h - le)
                    for i in range(p):
                        l = a[i + offset]
                        r = a[i + offset + p] * rot % mod
                        a[i + offset] = (l + r) % mod
                        a[i + offset + p] = (l - r) % mod
                    # ((~s & -~s) - 1).bit_length() is the index of the lowest 0 bit of s.
                    rot = rot * rate2[idx][((~s & -~s) - 1).bit_length()] % mod
                le += 1
            else:
                # Two levels at once: radix-4 butterflies.
                p = 1 << (h - le - 2)
                rot, imag = 1, root[idx][2]
                for s in range(1 << le):
                    rot2 = rot * rot % mod
                    rot3 = rot2 * rot % mod
                    offset = s << (h - le)
                    for i in range(p):
                        a0 = a[i + offset]
                        a1 = a[i + offset + p] * rot
                        a2 = a[i + offset + p * 2] * rot2
                        a3 = a[i + offset + p * 3] * rot3
                        a1na3imag = (a1 - a3) % mod * imag
                        a[i + offset] = (a0 + a2 + a1 + a3) % mod
                        a[i + offset + p] = (a0 + a2 - a1 - a3) % mod
                        a[i + offset + p * 2] = (a0 - a2 + a1na3imag) % mod
                        a[i + offset + p * 3] = (a0 - a2 - a1na3imag) % mod
                    rot = rot * rate3[idx][((~s & -~s) - 1).bit_length()] % mod
                le += 2
    else:
        coef = pow(n, mod - 2, mod)  # 1 / n, applied up front
        for i in range(n):
            a[i] = a[i] * coef % mod
        le = h
        while le:
            if le == 1:
                p = 1 << (h - le)
                irot = 1
                for s in range(1 << (le - 1)):
                    offset = s << (h - le + 1)
                    for i in range(p):
                        l = a[i + offset]
                        r = a[i + offset + p]
                        a[i + offset] = (l + r) % mod
                        a[i + offset + p] = (l - r) * irot % mod
                    irot = irot * irate2[idx][((~s & -~s) - 1).bit_length()] % mod
                le -= 1
            else:
                p = 1 << (h - le)
                irot, iimag = 1, iroot[idx][2]
                for s in range(1 << (le - 2)):
                    irot2 = irot * irot % mod
                    irot3 = irot2 * irot % mod
                    offset = s << (h - le + 2)
                    for i in range(p):
                        a0 = a[i + offset]
                        a1 = a[i + offset + p]
                        a2 = a[i + offset + p * 2]
                        a3 = a[i + offset + p * 3]
                        a2na3iimag = (a2 - a3) * iimag % mod
                        a[i + offset] = (a0 + a1 + a2 + a3) % mod
                        a[i + offset + p] = (a0 - a1 + a2na3iimag) * irot % mod
                        a[i + offset + p * 2] = (a0 + a1 - a2 - a3) * irot2 % mod
                        a[i + offset + p * 3] = (a0 - a1 - a2na3iimag) * irot3 % mod
                    irot *= irate3[idx][((~s & -~s) - 1).bit_length()]
                    irot %= mod
                le -= 2


def convolution_naive(a, b, mod=998244353):
    """Return the schoolbook product of ``a`` and ``b`` modulo ``mod``."""
    res = [0] * (len(a) + len(b) - 1)
    for i in range(len(a)):
        for j in range(len(b)):
            res[i + j] = (res[i + j] + a[i] * b[j] % mod) % mod
    return res


def convolution_ntt(a, b, mod=998244353):
    """Return the product of ``a`` and ``b`` modulo the NTT prime ``mod``.

    Inputs whose shorter side has at most 60 terms use the schoolbook method.
    The inputs are not modified.
    """
    s = a[:]
    t = b[:]
    n = len(s)
    m = len(t)
    if min(n, m) <= 60:
        return convolution_naive(s, t, mod)
    le = 1
    while le < n + m - 1:
        le *= 2
    s += [0] * (le - n)
    t += [0] * (le - m)
    ntt(s, 0, mod)
    ntt(t, 0, mod)
    for i in range(le):
        s[i] = s[i] * t[i] % mod
    ntt(s, 1, mod)
    return s[:n + m - 1]


def convolution_garner(f, g, mod):
    """Return the product of ``f`` and ``g`` modulo an arbitrary ``mod``.

    The exact product is computed modulo three NTT primes (five when its
    coefficients may reach their product) and recombined with ``Garner``.
    """
    if min(len(f), len(g)) <= 60:
        return convolution_naive(f, g, mod)
    MOD = [167772161, 469762049, 754974721]
    # Largest exact coefficient is at most (mod - 1)**2 * min(len(f), len(g)).
    if (mod - 1) * (mod - 1) * min(len(f), len(g)) >= 167772161 * 469762049 * 754974721:
        MOD += [880803841, 998244353]
    H = [convolution_ntt(f, g, q) for q in MOD]
    return [Garner([row[i] for row in H], MOD, mod) % mod for i in range(len(H[0]))]


def convolution(f, g, mod=998244353):
    """Return the product of ``f`` and ``g`` modulo ``mod`` (any modulus)."""
    if NTT_info(mod)[1] == -1:
        return convolution_garner(f, g, mod)
    return convolution_ntt(f, g, mod)


def fps_inv(f, deg=-1, mod=998244353):
    """Return the first ``deg`` coefficients of ``1 / f`` modulo ``mod``.

    ``deg == -1`` means ``len(f)``; ``deg <= 0`` (otherwise) yields ``[]``.
    ``f[0]`` must be invertible modulo ``mod``.
    """
    assert (f[0] != 0)
    if deg == -1:
        deg = len(f)
    if deg <= 0:
        return []
    n = len(f)
    if NTT_info(mod)[2] != -1:
        # NTT prime: Newton iteration with transforms of length 2 * le.
        g = [mod_inv(f[0], mod)] + [0] * (deg - 1)
        le = 1
        while le < deg:
            a = [0] * (2 * le)
            b = [0] * (2 * le)
            for i in range(min(n, 2 * le)):
                a[i] = f[i]
            for i in range(le):
                b[i] = g[i]
            ntt(a, 0, mod)
            ntt(b, 0, mod)
            for i in range(2 * le):
                a[i] = a[i] * b[i] % mod
            ntt(a, 1, mod)
            # f * g ≡ 1 (mod x^le): dropping the low half leaves f*g - 1.
            for i in range(le):
                a[i] = 0
            ntt(a, 0, mod)
            for i in range(2 * le):
                a[i] = a[i] * b[i] % mod
            ntt(a, 1, mod)
            for j in range(le, min(deg, 2 * le)):
                g[j] = (mod - a[j]) % mod
            le *= 2
        return g
    # Not an NTT prime: Newton doubling g <- 2g - f g^2 with generic convolution.
    g = [0] * deg
    g[0] = mod_inv(f[0], mod)
    le = 1
    while le < deg:
        gg = convolution(g[:le], g[:le], mod)
        gg = convolution(f[:min(2 * le, n)], gg, mod)
        for i in range(min(deg, 2 * le)):
            # gg is shorter than 2 * le when f has few terms; missing terms are 0.
            t = gg[i] if i < len(gg) else 0
            g[i] = (2 * g[i] - t) % mod
        le *= 2
    return g


def Bostan_Mori_ntt(P, Q, N, mod=998244353):
    """Return ``[x^N] P(x) / Q(x)`` modulo the NTT prime ``mod`` (Bostan–Mori).

    Each round replaces P/Q by P(x)Q(-x) / Q(x)Q(-x), keeps the coefficients
    of parity ``N % 2`` and halves N.  The pointwise step pairs transform
    entries 2i and 2i+1, relying on ``ntt`` placing the evaluations at w and
    -w next to each other.  ``Q[0]`` must be invertible modulo ``mod``.
    """
    f, g = P[:], Q[:]
    # The transform must hold P(x)Q(-x) without wrap-around; sizing by Q alone
    # broke whenever len(P) > len(Q).
    le = 2
    while le < 2 * max(len(P), len(Q)):
        le *= 2
    f += [0] * (le - len(f))
    g += [0] * (le - len(g))
    half = le // 2
    while N:
        ntt(f, 0, mod)
        ntt(g, 0, mod)
        for i in range(half):
            e, o = 2 * i, 2 * i + 1
            f[e], f[o] = f[e] * g[o] % mod, f[o] * g[e] % mod
            g[e] = g[o] = g[e] * g[o] % mod
        ntt(f, 1, mod)
        ntt(g, 1, mod)
        r = N % 2
        # Keep the coefficients of matching parity, compacted into the low half.
        g[:half] = g[0::2]
        g[half:] = [0] * half
        f[:half] = f[r::2]
        f[half:] = [0] * half
        N //= 2
    # At N == 0 the answer is f(0) / g(0); the old code ignored g(0).
    return f[0] * mod_inv(g[0], mod) % mod


def Bostan_Mori(P, Q, N, mod=998244353):
    """Return ``[x^N] P(x) / Q(x)`` modulo any ``mod`` (``Q[0]`` invertible).

    NTT primes are delegated to ``Bostan_Mori_ntt``; other moduli use the
    generic ``convolution``.  ``mod`` is now forwarded in both cases (it was
    previously dropped, silently computing modulo 998244353).
    """
    if NTT_info(mod)[2] != -1:
        return Bostan_Mori_ntt(P, Q, N, mod)
    f, g = P[:], Q[:]
    while N:
        g2 = g[:]
        for i in range(1, len(g2), 2):
            g2[i] = (mod - g2[i]) % mod  # Q(-x)
        f = convolution(f, g2, mod)
        g = convolution(g, g2, mod)
        r = N % 2
        S = [0] * ((len(f) + 1) // 2)
        for i in range(r, len(f), 2):
            S[i // 2] = (S[i // 2] + f[i]) % mod
        f = S
        g = g[0::2]
        N //= 2
    return f[0] * mod_inv(g[0], mod) % mod


def main():
    """Read ``N K`` from stdin and print the answer for yukicoder No.3080."""
    N, K = map(int, input().split())
    if N == 1:
        print(0)
        return
    # Numerator: 1 + x + ... + x^(K-1).
    f = [1] * K
    # Denominator 1 - 3x + 2x^2 + x^(K+1) - 2x^(K+2) + x^(2K+2), built by
    # accumulation so coinciding indices (small K) add up.
    g = [0] * (2 * K + 3)
    g[0] += 1
    g[1] -= 3
    g[2] += 2
    g[K + 1] += 1
    g[K + 2] -= 2
    g[2 * K + 2] += 1
    ans = (Bostan_Mori_ntt(f, g, N - 2) + 1) % mod
    print(ans)


if __name__ == "__main__":
    main()