結果
問題 |
No.1068 #いろいろな色 / Red and Blue and more various colors (Hard)
|
ユーザー |
![]() |
提出日時 | 2025-03-31 17:24:27 |
言語 | PyPy3 (7.3.15) |
結果 |
AC
|
実行時間 | 3,435 ms / 3,500 ms |
コード長 | 2,411 bytes |
コンパイル時間 | 140 ms |
コンパイル使用メモリ | 82,396 KB |
実行使用メモリ | 158,180 KB |
最終ジャッジ日時 | 2025-03-31 17:25:52 |
合計ジャッジ時間 | 65,168 ms |
ジャッジサーバーID (参考情報) |
judge5 / judge3 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | AC * 3 |
other | AC * 29 |
ソースコード
import sys

# NTT-friendly prime: MOD - 1 = 2**23 * 7 * 17, so power-of-two roots of unity
# exist for every transform length up to 2**23.
MOD = 998244353
# Primitive root modulo MOD, used to derive the roots of unity.
G = 3


def ntt(a, invert=False):
    """Apply an in-place iterative radix-2 number-theoretic transform.

    Args:
        a: Mutable sequence of integers; overwritten with the transform.
            Its length must be a power of two (an empty sequence is a no-op)
            and must not exceed 2**23 for this modulus.
        invert: When true, compute the inverse transform, including the
            final scaling by the modular inverse of ``len(a)``.

    Raises:
        ValueError: If ``len(a)`` is not a power of two.  The bit-reversal
            loop below would otherwise never terminate for such lengths.
    """
    n = len(a)
    if n & (n - 1):
        raise ValueError("ntt length must be a power of two")

    # Bit-reversal permutation: j always holds the bit-reversed value of i.
    j = 0
    for i in range(1, n):
        bit = n >> 1
        while j >= bit:
            j -= bit
            bit >>= 1
        j += bit
        if i < j:
            a[i], a[j] = a[j], a[i]

    # Butterfly passes over blocks of size 2, 4, ..., n.
    length = 2
    while length <= n:
        half = length >> 1
        omega = pow(G, (MOD - 1) // length, MOD)
        if invert:
            omega = pow(omega, MOD - 2, MOD)
        for start in range(0, n, length):
            w = 1
            for k in range(start, start + half):
                u = a[k]
                v = a[k + half] * w % MOD
                a[k] = (u + v) % MOD
                a[k + half] = (u - v) % MOD
                w = w * omega % MOD
        length <<= 1

    if invert:
        inv_n = pow(n, MOD - 2, MOD)
        for i, value in enumerate(a):
            a[i] = value * inv_n % MOD


def multiply_ntt(a, b):
    """Return the product of two polynomials modulo MOD.

    Polynomials are coefficient sequences ordered from the constant term
    upwards.  Neither argument is modified: the padding and the transforms
    are done on private copies.

    Args:
        a: Coefficients of the first polynomial.
        b: Coefficients of the second polynomial.

    Returns:
        A new list of ``len(a) + len(b) - 1`` coefficients in ``[0, MOD)``,
        or an empty list when either input is empty.
    """
    if not a or not b:
        return []

    result_len = len(a) + len(b) - 1
    size = 1
    while size < result_len:
        size <<= 1

    # Work on zero-padded copies so the caller's lists stay intact.
    fa = list(a) + [0] * (size - len(a))
    fb = list(b) + [0] * (size - len(b))
    ntt(fa)
    ntt(fb)
    product = [x * y % MOD for x, y in zip(fa, fb)]
    ntt(product, invert=True)
    del product[result_len:]
    return product


def product_polynomials(c_list):
    """Return the coefficients of the product of ``(1 + c*x)`` over ``c_list``.

    The product is built by divide and conquer: each half is expanded
    recursively and the halves are combined with ``multiply_ntt``.

    Args:
        c_list: Sequence of integer coefficients ``c``.

    Returns:
        A list of ``len(c_list) + 1`` coefficients modulo MOD, constant term
        first.  An empty ``c_list`` yields ``[1]``.
    """
    if not c_list:
        return [1]
    if len(c_list) == 1:
        return [1, c_list[0] % MOD]
    mid = len(c_list) // 2
    left = product_polynomials(c_list[:mid])
    right = product_polynomials(c_list[mid:])
    return multiply_ntt(left, right)


def main():
    """Read N, Q, the N values A and the Q queries B from stdin; answer each query.

    For a query B the printed value is the coefficient of ``x**(N - B)`` in
    the product of ``(1 + (A_i - 1) * x)`` over all i, modulo MOD, or 0 when
    ``N - B`` lies outside the polynomial's degree range.
    """
    tokens = sys.stdin.read().split()
    n = int(tokens[0])
    q = int(tokens[1])
    values = map(int, tokens[2:2 + n])
    queries = map(int, tokens[2 + n:2 + n + q])

    # Entries equal to 1 contribute the trivial factor (1 + 0*x) = 1, so they
    # are left out of the product; they still count towards n, which is what
    # keeps the index n - B correct below.
    coeffs = [(a - 1) % MOD for a in values if a != 1]
    poly = product_polynomials(coeffs)

    # len(poly) == len(coeffs) + 1, so this single range check covers both
    # "B larger than n" and "B smaller than the number of ones".
    size = len(poly)
    answers = [poly[n - b] if 0 <= n - b < size else 0 for b in queries]

    # Emit everything in one write instead of one print call per query.
    if answers:
        sys.stdout.write("\n".join(map(str, answers)) + "\n")


if __name__ == "__main__":
    main()