結果
問題 | No.3190 Scoring |
ユーザー | |
提出日時 | 2025-06-16 10:51:31 |
言語 | PyPy3 (7.3.15) |
結果 | AC |
実行時間 | 5,562 ms / 10,000 ms |
コード長 | 3,323 bytes |
コンパイル時間 | 399 ms |
コンパイル使用メモリ | 82,056 KB |
実行使用メモリ | 454,392 KB |
最終ジャッジ日時 | 2025-06-20 20:54:23 |
合計ジャッジ時間 | 114,060 ms |
ジャッジサーバーID (参考情報) | judge4 / judge5 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | AC * 4 |
other | AC * 32 |
ソースコード
from collections import deque

# NTT-friendly prime: mod - 1 = 119 * 2**23 and 3 is a primitive root, so
# power-of-two transform lengths up to 2**23 are supported.
mod = 998244353


def fft(a):
    """Forward number-theoretic transform of ``a`` in place, modulo ``mod``.

    Uses Gentleman-Sande (decimation-in-frequency) butterflies with no
    normalisation. ``len(a)`` must be a power of two (at most 2**23).

    The output is left in bit-reversed order. ``ifft`` consumes exactly that
    order, so pointwise products of ``fft`` outputs can be passed to ``ifft``
    without any permutation step.
    """
    n = len(a)
    m = n
    while m >= 2:
        m2 = m // 2
        w = pow(3, (mod - 1) // m, mod)  # primitive m-th root of unity
        for i in range(0, n, m):
            wj = 1
            for j in range(m2):
                u = a[i + j]
                v = a[i + j + m2]
                a[i + j] = (u + v) % mod
                a[i + j + m2] = (u - v) * wj % mod
                wj = wj * w % mod
        m //= 2


def ifft(a):
    """Inverse of ``fft`` in place, up to a factor of ``len(a)``.

    Uses Cooley-Tukey (decimation-in-time) butterflies with inverse roots.
    Expects bit-reversed input, as produced by ``fft``, and leaves the result
    in natural order. The caller must still multiply by
    ``pow(len(a), mod - 2, mod)``.
    """
    n = len(a)
    m = 2
    while m <= n:
        m2 = m // 2
        w = pow(3, mod - 1 - (mod - 1) // m, mod)  # inverse of fft's root
        for i in range(0, n, m):
            wj = 1
            for j in range(m2):
                u = a[i + j]
                v = a[i + j + m2] * wj
                a[i + j] = (u + v) % mod
                a[i + j + m2] = (u - v) % mod
                wj = wj * w % mod
        m *= 2


def convolution(a, b):
    """Return the linear convolution of ``a`` and ``b`` modulo ``mod``.

    The result has ``len(a) + len(b) - 1`` coefficients, or is ``[]`` if
    either input is empty. The inputs are not modified.
    """
    if not a or not b:
        return []
    n = len(a) + len(b) - 1
    m = 1 << (n - 1).bit_length()
    fa = a + [0] * (m - len(a))
    fb = b + [0] * (m - len(b))
    fft(fa)
    fft(fb)
    for i in range(m):
        fa[i] = fa[i] * fb[i] % mod
    ifft(fa)
    inv_m = pow(m, mod - 2, mod)
    return [x * inv_m % mod for x in fa[:n]]


def add(a, b):
    """Return ``a + b`` coefficient-wise, reduced modulo ``mod``.

    Trailing zero coefficients are stripped from the result. Every
    coefficient is reduced, including the tail of the longer operand.
    Without that, zero-mod-``mod`` tails would escape trimming and oversized
    integers would leak into later transforms.
    """
    if len(a) < len(b):
        a, b = b, a
    c = [x % mod for x in a]
    for i, x in enumerate(b):
        c[i] = (c[i] + x) % mod
    while c and c[-1] == 0:
        c.pop()
    return c


def inv(a, k):
    """Return the first ``k`` coefficients of the power series ``1 / a(x)``.

    Uses Newton iteration. Given ``b = 1/a mod x**n``, the next ``n``
    coefficients are ``-(a*b)[n:2n] * b`` truncated to ``n`` terms, so the
    precision doubles each round.

    Raises:
        ValueError: if ``a[0] != 1``; the iteration is seeded with ``b = [1]``.
    """
    if a[0] != 1:
        raise ValueError("inv requires a[0] == 1")
    a = a + [0] * (k - len(a))
    b = [1]
    n = 1
    while n < k:
        ab = convolution(a[:2 * n], b)
        d = [mod - ab[i] for i in range(n, 2 * n)]
        db = convolution(d, b)
        b = b + db[:n]
        n *= 2
    return b[:k]


def _factorial_tables(limit):
    """Return ``(fact, fact_inv)``, i.e. ``i!`` and ``(i!)^-1`` mod ``mod``, for ``0 <= i <= limit``."""
    fact = [1] * (limit + 1)
    for i in range(2, limit + 1):
        fact[i] = fact[i - 1] * i % mod
    fact_inv = [1] * (limit + 1)
    fact_inv[limit] = pow(fact[limit], mod - 2, mod)
    for i in range(limit, 1, -1):
        fact_inv[i - 1] = fact_inv[i] * i % mod
    return fact, fact_inv


def _majority_weights(m, fact, fact_inv):
    """Return ``w[i - half]`` for ``half <= i <= m``, where ``half = ceil(m / 2)``.

    Each weight is ``w_i = C(m, i) * sum_{k=half}^{i} (-1)**(i-k) * C(i, k)``.
    By inclusion-exclusion, ``sum_i w_i * q**i == P(Binomial(m, q) >= half)``.
    This turns power sums of the ``p_j`` into binomial tail probabilities.

    The inner sums for all ``i`` are computed at once as one convolution of
    ``1/k!`` (for ``k >= half``) with ``(-1)**t / t!``.
    """
    half = (m + 1) // 2
    inner = convolution(
        [fact_inv[i] for i in range(half, m + 1)],
        [fact_inv[i] if i % 2 == 0 else mod - fact_inv[i] for i in range(m + 1)],
    )
    return [
        inner[i - half] * fact[m] % mod * fact_inv[m - i] % mod
        for i in range(half, m + 1)
    ]


def _add_fractions(a1, b1, a2, b2):
    """Add two rational functions stored as ``A / (1 + x*B)``.

    Returns ``(A3, B3)`` such that
    ``A3 / (1 + x*B3) == A1 / (1 + x*B1) + A2 / (1 + x*B2)``, where::

        B3 = B1 + B2 + x * B1 * B2
        A3 = A1 + A2 + x * (A1 * B2 + A2 * B1)

    All three products share one set of forward transforms.
    """
    # B3's leading coefficient is lead(B1) * lead(B2), which is nonzero in a
    # prime field. With the single-term seeds this keeps len(A) <= len(B)
    # (assuming every seed p_j is nonzero mod `mod`). All three products
    # therefore fit in `size` coefficients with no cyclic wrap-around.
    size = len(b1) + len(b2) - 1
    size_n = 1 << (size - 1).bit_length()
    size_n_inv = pow(size_n, mod - 2, mod)

    fa1 = a1 + [0] * (size_n - len(a1))
    fb1 = b1 + [0] * (size_n - len(b1))
    fa2 = a2 + [0] * (size_n - len(a2))
    fb2 = b2 + [0] * (size_n - len(b2))
    for poly in (fa1, fb1, fa2, fb2):
        fft(poly)

    den = [x * y % mod for x, y in zip(fb1, fb2)]
    num = [(x1 * y2 + y1 * x2) % mod for x1, y1, x2, y2 in zip(fa1, fb1, fa2, fb2)]
    ifft(den)
    ifft(num)

    b3 = add([0] + [x * size_n_inv % mod for x in den[:size]], add(b1, b2))
    a3 = add([0] + [x * size_n_inv % mod for x in num[:size]], add(a1, a2))
    return a3, b3


def _power_sums(p, m):
    """Return a list whose entry ``k`` is ``sum_j p[j]**k`` mod ``mod``, for ``0 <= k <= m``.

    ``p`` must be non-empty. The list may be longer than ``m + 1``.

    The generating function ``sum_j 1 / (1 - p_j x)`` is built by pairwise
    divide-and-conquer merging in FIFO order. Consumed fractions are popped,
    so their memory is released. The result is then expanded as a power
    series.
    """
    queue = deque(([1], [mod - x]) for x in p)  # 1 / (1 + x * (-p_j))
    while len(queue) > 1:
        a1, b1 = queue.popleft()
        a2, b2 = queue.popleft()
        queue.append(_add_fractions(a1, b1, a2, b2))
    num, den_tail = queue[0]
    return convolution(num, inv([1] + den_tail, m + 1))


def solve(n, s, m):
    """Return ``n * sum_{j=1}^{s} P(Binomial(m, p_j) >= ceil(m/2))`` modulo ``mod``.

    Here ``p_j = C(s - j + n - 1, n - 1) / C(s + n - 1, n - 1)``. By stars
    and bars, ``p_j`` is the probability that a fixed part of a uniform weak
    composition of ``s`` into ``n`` parts is ``>= j``. That interpretation
    of the problem is inferred from the formula; confirm it against the
    statement.

    Returns 0 when ``n <= 0`` or ``s <= 0`` (the sum is empty).
    """
    if n <= 0 or s <= 0:
        return 0
    fact, fact_inv = _factorial_tables(max(m, s + n - 1))
    half = (m + 1) // 2
    weights = _majority_weights(m, fact, fact_inv)
    scale = fact[s] * fact_inv[s + n - 1] % mod
    p = [
        fact[s - j + n - 1] * fact_inv[s - j] % mod * scale % mod
        for j in range(1, s + 1)
    ]
    sums = _power_sums(p, m)
    ans = sum(sums[i] * weights[i - half] for i in range(half, m + 1)) % mod
    return ans * n % mod


def main():
    """Read ``n s m`` from stdin and print the answer."""
    n, s, m = map(int, input().split())
    print(solve(n, s, m))


if __name__ == "__main__":
    main()