結果

問題 No.1068 #いろいろな色 / Red and Blue and more various colors (Hard)
ユーザー lam6er
提出日時 2025-03-31 17:24:27
言語 PyPy3
(7.3.15)
結果
AC  
実行時間 3,435 ms / 3,500 ms
コード長 2,411 bytes
コンパイル時間 140 ms
コンパイル使用メモリ 82,396 KB
実行使用メモリ 158,180 KB
最終ジャッジ日時 2025-03-31 17:25:52
合計ジャッジ時間 65,168 ms
ジャッジサーバーID
(参考情報)
judge5 / judge3
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 3
other AC * 29
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
MOD = 998244353
G = 3

def ntt(a, invert=False):
    n = len(a)
    j = 0
    for i in range(1, n):
        bit = n >> 1
        while j >= bit:
            j -= bit
            bit >>= 1
        j += bit
        if i < j:
            a[i], a[j] = a[j], a[i]
    l = 2
    while l <= n:
        omega = pow(G, (MOD-1)//l, MOD)
        if invert:
            omega = pow(omega, MOD-2, MOD)
        for i in range(0, n, l):
            w = 1
            for j in range(l//2):
                u = a[i+j]
                v = a[i+j + l//2] * w % MOD
                a[i+j] = (u + v) % MOD
                a[i+j + l//2] = (u - v) % MOD
                w = w * omega % MOD
        l <<= 1
    if invert:
        inv = pow(n, MOD-2, MOD)
        for i in range(n):
            a[i] = a[i] * inv % MOD

def multiply_ntt(a, b):
    len_ab = len(a) + len(b) - 1
    n = 1
    while n < len_ab:
        n <<= 1
    a += [0] * (n - len(a))
    b += [0] * (n - len(b))
    ntt(a)
    ntt(b)
    c = [(a[i] * b[i]) % MOD for i in range(n)]
    ntt(c, invert=True)
    del c[len_ab:]
    return c

def product_polynomials(c_list):
    if len(c_list) == 0:
        return [1]
    if len(c_list) == 1:
        return [1, c_list[0] % MOD]
    mid = len(c_list) // 2
    left = product_polynomials(c_list[:mid])
    right = product_polynomials(c_list[mid:])
    return multiply_ntt(left, right)

def main():
    input = sys.stdin.read().split()
    ptr = 0
    N = int(input[ptr])
    ptr += 1
    Q = int(input[ptr])
    ptr += 1
    
    A = list(map(int, input[ptr:ptr+N]))
    ptr += N
    B_list = list(map(int, input[ptr:ptr+Q]))
    ptr += Q
    
    # Process A and split into zero and non-zero groups
    K = 0
    c_list = []
    for a in A:
        if a == 1:
            K += 1
        else:
            c = (a - 1) % MOD
            c_list.append(c)
    M = len(c_list)
    
    # Precompute the product polynomials
    if M == 0:
        poly = [1]
    else:
        poly = product_polynomials(c_list)
        poly = [x % MOD for x in poly]
    
    # Answer queries
    for B in B_list:
        if B < K:
            print(0)
        else:
            d = N - B
            if d < 0 or d > M:
                print(0)
            else:
                if d >= len(poly):
                    print(0)
                else:
                    print(poly[d] % MOD)
    
if __name__ == "__main__":
    main()
0