結果

問題 No.1068 #いろいろな色 / Red and Blue and more various colors (Hard)
ユーザー lam6er
提出日時 2025-03-26 15:46:24
言語 PyPy3
(7.3.15)
結果
RE  
実行時間 -
コード長 2,537 bytes
コンパイル時間 257 ms
コンパイル使用メモリ 82,280 KB
実行使用メモリ 146,288 KB
最終ジャッジ日時 2025-03-26 15:46:49
合計ジャッジ時間 6,452 ms
ジャッジサーバーID
(参考情報)
judge2 / judge5
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 3
other AC * 1 RE * 28
権限があれば一括ダウンロードができます

ソースコード

diff #

MOD = 998244353
ROOT = 3

def ntt(a, invert=False):
    n = len(a)
    j = 0
    for i in range(1, n-1):
        bit = n >> 1
        while j >= bit:
            j -= bit
            bit >>= 1
        j += bit
        if i < j:
            a[i], a[j] = a[j], a[i]
    log_n = (n).bit_length() - 1
    root = pow(ROOT, (MOD - 1) // n, MOD) if not invert else pow(ROOT, (MOD - 1) - (MOD - 1) // n, MOD)
    roots = [1] * (n // 2)
    for i in range(1, n // 2):
        roots[i] = roots[i-1] * root % MOD
    for L in range(2, n + 1, 2):
        L_half = L // 2
        step = n // L
        for i in range(0, n, L):
            for j in range(L_half):
                idx_e = i + j
                idx_o = i + j + L_half
                even = a[idx_e]
                odd = a[idx_o] * roots[j * step] % MOD
                a[idx_e] = (even + odd) % MOD
                a[idx_o] = (even - odd) % MOD
                if a[idx_o] < 0:
                    a[idx_o] += MOD
    if invert:
        inv_n = pow(n, MOD-2, MOD)
        for i in range(n):
            a[i] = a[i] * inv_n % MOD
    return a

def multiply_ntt(a, b):
    len_a = len(a)
    len_b = len(b)
    if len_a == 0 or len_b == 0:
        return []
    max_len = len_a + len_b - 1
    n = 1
    while n < max_len:
        n <<= 1
    a_ntt = a + [0] * (n - len_a)
    b_ntt = b + [0] * (n - len_b)
    a_ntt = ntt(a_ntt)
    b_ntt = ntt(b_ntt)
    c_ntt = [(x * y) % MOD for x, y in zip(a_ntt, b_ntt)]
    c = ntt(c_ntt, invert=True)
    return c[:max_len]

def compute_polynomial(d_list):
    if len(d_list) == 0:
        return [1]
    if len(d_list) == 1:
        return [1, d_list[0]]
    mid = len(d_list) // 2
    left = compute_polynomial(d_list[:mid])
    right = compute_polynomial(d_list[mid:])
    return multiply_ntt(left, right)

def main():
    import sys
    input = sys.stdin.read().split()
    ptr = 0
    N = int(input[ptr])
    ptr += 1
    Q = int(input[ptr])
    ptr += 1
    A = list(map(int, input[ptr:ptr+N]))
    ptr += N
    B_list = list(map(int, input[ptr:ptr+Q]))
    d_list = [(a - 1) % MOD for a in A]
    non_zero = [d for d in d_list if d != 0]
    M = len(non_zero)
    if M == 0:
        product_coeffs = [1]
    else:
        product_coeffs = compute_polynomial(non_zero)
    for B in B_list:
        K = N - B
        if K < 0 or K > M:
            print(0)
        else:
            if K < len(product_coeffs):
                print(product_coeffs[K] % MOD)
            else:
                print(0)

if __name__ == "__main__":
    main()
0