結果
| 問題 | No.1068 #いろいろな色 / Red and Blue and more various colors (Hard) |
| コンテスト | |
| ユーザー | lam6er |
| 提出日時 | 2025-03-31 17:24:27 |
| 言語 | PyPy3 (7.3.15) |
| 結果 | AC |
| 実行時間 | 3,435 ms / 3,500 ms |
| コード長 | 2,411 bytes |
| コンパイル時間 | 140 ms |
| コンパイル使用メモリ | 82,396 KB |
| 実行使用メモリ | 158,180 KB |
| 最終ジャッジ日時 | 2025-03-31 17:25:52 |
| 合計ジャッジ時間 | 65,168 ms |
| ジャッジサーバーID (参考情報) | judge5 / judge3 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 3 |
| other | AC * 29 |
ソースコード
import sys
# NTT-friendly prime: 998244353 = 119 * 2**23 + 1, so every power-of-two
# transform length up to 2**23 divides MOD - 1 and has a root of unity.
MOD = 998244353
# Generator from which ntt() derives its roots of unity (3 is a primitive
# root of 998244353).
G = 3


def ntt(a, invert=False):
    """Transform ``a`` in place with a number-theoretic transform mod MOD.

    ``a`` is overwritten with its forward transform or, when ``invert`` is
    true, with its inverse transform (already divided by ``len(a)``).  The
    length of ``a`` must be a power of two no larger than 2**23 (or 0) so
    that ``(MOD - 1) // length`` below is exact.  Returns None; callers rely
    on the in-place update.
    """
    n = len(a)
    # Permute a into bit-reversed index order so the butterflies below can
    # run in place; j holds the bit-reversal of i.
    j = 0
    for i in range(1, n):
        bit = n >> 1
        while j >= bit:
            j -= bit
            bit >>= 1
        j += bit
        if i < j:
            a[i], a[j] = a[j], a[i]
    length = 2
    while length <= n:
        half = length >> 1
        # Primitive length-th root of unity (its inverse for the inverse
        # transform).
        omega = pow(G, (MOD - 1) // length, MOD)
        if invert:
            omega = pow(omega, MOD - 2, MOD)
        # omega**k depends only on the level, not on the block, so build the
        # table once per level instead of re-deriving it for every block.
        twiddles = [1] * half
        for k in range(1, half):
            twiddles[k] = twiddles[k - 1] * omega % MOD
        for start in range(0, n, length):
            for k in range(half):
                lo = start + k
                hi = lo + half
                u = a[lo]
                v = a[hi] * twiddles[k] % MOD
                a[lo] = (u + v) % MOD
                a[hi] = (u - v) % MOD
        length <<= 1
    if invert:
        # Scale by n^{-1} (Fermat inverse) to finish the inverse transform.
        inv_n = pow(n, MOD - 2, MOD)
        for i in range(n):
            a[i] = a[i] * inv_n % MOD
def multiply_ntt(a, b):
    """Return the product of coefficient lists ``a`` and ``b`` modulo MOD.

    Coefficients are ordered from the constant term upward.  The result has
    ``len(a) + len(b) - 1`` entries, each reduced into ``range(MOD)``, and is
    ``[]`` when either input is empty.  ``a`` and ``b`` are not modified.
    The padded transform length must not exceed 2**23, the largest power of
    two dividing MOD - 1.
    """
    if not a or not b:
        return []
    len_ab = len(a) + len(b) - 1
    # For short operands the fixed cost of three transforms dominates, so
    # schoolbook multiplication is faster and gives the same result mod MOD.
    if min(len(a), len(b)) <= 32:
        c = [0] * len_ab
        for i, x in enumerate(a):
            for j, y in enumerate(b):
                c[i + j] = (c[i + j] + x * y) % MOD
        return c
    # Smallest power of two that holds the full (non-cyclic) product.
    n = 1 << (len_ab - 1).bit_length()
    # Pad copies, not the caller's lists: ntt() transforms in place.
    fa = list(a) + [0] * (n - len(a))
    fb = list(b) + [0] * (n - len(b))
    ntt(fa)
    ntt(fb)
    c = [x * y % MOD for x, y in zip(fa, fb)]
    ntt(c, invert=True)
    del c[len_ab:]
    return c
def product_polynomials(c_list):
    """Expand prod(1 + c * x for c in c_list) into coefficients modulo MOD.

    Coefficients run from the constant term upward, and an empty ``c_list``
    yields ``[1]``.  Factors are combined as a balanced binary tree: each
    half is expanded recursively and the halves are joined with
    multiply_ntt.
    """
    count = len(c_list)
    if count >= 2:
        split = count // 2
        lower = product_polynomials(c_list[:split])
        upper = product_polynomials(c_list[split:])
        return multiply_ntt(lower, upper)
    if count == 1:
        # A single linear factor 1 + c*x.
        return [1, c_list[0] % MOD]
    return [1]
def main():
    """Read N, Q, the N values A and the Q values B; print one answer per B.

    For each B the answer is the coefficient of x**(N - B) in the product of
    (1 + (A_i - 1) * x) over all A_i != 1, reduced modulo MOD, or 0 when
    N - B lies outside that polynomial's coefficient range.  Answers are
    written to stdout, one per line.
    """
    tokens = sys.stdin.buffer.read().split()
    n = int(tokens[0])
    q = int(tokens[1])
    a_values = map(int, tokens[2:2 + n])
    b_values = map(int, tokens[2 + n:2 + n + q])
    # A_i == 1 contributes the factor (1 + 0*x) = 1, so it is skipped; the
    # polynomial's degree is then the number of A_i != 1.
    c_list = [(a - 1) % MOD for a in a_values if a != 1]
    # product_polynomials([]) already returns [1]; no special case needed.
    poly = product_polynomials(c_list)
    size = len(poly)
    # d = N - B outside [0, degree] (which includes B < #{A_i == 1}) gives 0.
    answers = [
        poly[n - b] % MOD if 0 <= n - b < size else 0
        for b in b_values
    ]
    # One buffered write instead of a print() call per query.
    if answers:
        sys.stdout.write("\n".join(map(str, answers)) + "\n")


if __name__ == "__main__":
    main()
lam6er