結果

問題 No.1066 #いろいろな色 / Red and Blue and more various colors (Easy)
ユーザー Salmonize
提出日時 2020-05-30 00:38:16
言語 Python3
(3.12.2 + numpy 1.26.4 + scipy 1.12.0)
結果
AC  
実行時間 792 ms / 2,000 ms
コード長 3,554 bytes
コンパイル時間 283 ms
コンパイル使用メモリ 13,184 KB
実行使用メモリ 12,544 KB
最終ジャッジ日時 2024-11-06 10:04:59
合計ジャッジ時間 8,337 ms
ジャッジサーバーID
(参考情報)
judge4 / judge1
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 32 ms
11,136 KB
testcase_01 AC 31 ms
11,136 KB
testcase_02 AC 32 ms
11,136 KB
testcase_03 AC 31 ms
11,136 KB
testcase_04 AC 32 ms
11,136 KB
testcase_05 AC 31 ms
11,008 KB
testcase_06 AC 32 ms
11,136 KB
testcase_07 AC 31 ms
11,136 KB
testcase_08 AC 405 ms
11,776 KB
testcase_09 AC 35 ms
11,008 KB
testcase_10 AC 784 ms
12,544 KB
testcase_11 AC 357 ms
11,776 KB
testcase_12 AC 352 ms
11,904 KB
testcase_13 AC 368 ms
11,648 KB
testcase_14 AC 364 ms
11,776 KB
testcase_15 AC 748 ms
12,544 KB
testcase_16 AC 193 ms
11,264 KB
testcase_17 AC 206 ms
11,264 KB
testcase_18 AC 726 ms
12,288 KB
testcase_19 AC 364 ms
11,776 KB
testcase_20 AC 35 ms
11,136 KB
testcase_21 AC 441 ms
12,032 KB
testcase_22 AC 36 ms
11,008 KB
testcase_23 AC 29 ms
11,136 KB
testcase_24 AC 31 ms
11,136 KB
testcase_25 AC 792 ms
12,416 KB
testcase_26 AC 586 ms
12,288 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys

# Bound once so every helper below shares the same fast line reader.
readline = sys.stdin.readline


def ns():
    """Return the next stdin line with trailing whitespace removed."""
    return readline().rstrip()


def ni():
    """Return the next stdin line parsed as a single integer."""
    return int(readline().rstrip())


def nm():
    """Return a lazy iterator of the integers on the next stdin line."""
    return map(int, readline().split())


def nl():
    """Return the integers on the next stdin line as a list."""
    return list(map(int, readline().split()))


def modinv(x, mod):
    """Return the inverse of ``x`` modulo ``mod`` via the extended Euclidean algorithm.

    The result lies in ``[0, mod)``. It is a true inverse only when ``x`` and
    ``mod`` are coprime; no check is performed for other inputs.
    """
    # (r_prev, r_cur) is the remainder sequence; (s_prev, s_cur) tracks the
    # Bezout coefficient of x, so that s * x == r (mod `mod`) at every step.
    r_prev, r_cur = x, mod
    s_prev, s_cur = 1, 0
    while r_cur != 0:
        q = r_prev // r_cur
        r_prev, r_cur = r_cur, r_prev - q * r_cur
        s_prev, s_cur = s_cur, s_prev - q * s_cur
    return s_prev % mod


def _garner(xs, mods):
    """Combine residues with Garner's algorithm and reduce by the last modulus.

    For every ``i`` except the last, ``xs[i]`` is taken as a residue modulo
    ``mods[i]``. The value reconstructed from those congruences is returned
    reduced modulo ``mods[-1]``; ``xs[-1]`` itself is never read.

    NOTE(review): assumes the moduli other than the last are pairwise
    coprime, otherwise the modular inverses below are meaningless.
    """
    count = len(xs)
    # For each j: weight[j] is the product of the moduli processed so far and
    # partial[j] is the value reconstructed so far, both reduced mod mods[j].
    weight = [1] * count
    partial = [0] * count
    for i in range(count - 1):
        m_i = mods[i]
        # Mixed-radix digit that makes the running value agree with xs[i].
        digit = (xs[i] - partial[i]) * modinv(weight[i], m_i) % m_i
        for j in range(i + 1, count):
            m_j = mods[j]
            # partial must be updated with the old weight before it is scaled.
            partial[j] = (partial[j] + weight[j] * digit) % m_j
            weight[j] = weight[j] * m_i % m_j

    return partial[-1]


def bit_reverse(d):
    """Reorder ``d`` in place into bit-reversed index order and return it.

    Only even indices of the lower half are visited; each visit performs up
    to three swaps, which covers the matching odd / upper-half positions too.

    NOTE(review): the permutation is a bit reversal for power-of-two lengths;
    presumably callers never pass any other length.
    """
    size = len(d)
    half = size >> 1
    quarter = half >> 1
    mirror = half + 1  # offset of the partner pair in the upper half
    rev = 0  # bit-reversed counterpart of `idx`
    for idx in range(0, half, 2):
        if idx < rev:
            d[idx], d[rev] = d[rev], d[idx]
            hi_idx = idx + mirror
            hi_rev = rev + mirror
            d[hi_idx], d[hi_rev] = d[hi_rev], d[hi_idx]
        odd_pos = rev + 1
        upper_pos = idx + half
        d[odd_pos], d[upper_pos] = d[upper_pos], d[odd_pos]
        # Advance `rev` as a reversed counter: flip bits from the top down
        # until a bit flips from 0 to 1.
        bit = quarter
        rev ^= bit
        while bit > rev:
            bit >>= 1
            rev ^= bit
    return d


class NTT:
    """Number-theoretic transform and convolution modulo a fixed prime.

    ``mod`` is the modulus and ``primitive_root`` the generator from which
    roots of unity are derived. Transform lengths are expected to be powers
    of two that divide ``mod - 1``; this is not validated.
    """

    def __init__(self, mod, primitive_root):
        self.mod = mod
        self.root = primitive_root

    def _ntt(self, a, sign):
        """Transform ``a`` in place and return it.

        ``sign >= 0`` uses the forward root of unity, ``sign < 0`` its
        inverse. No 1/n scaling is applied here (see ``intt``).
        """
        n = len(a)
        mod, g = self.mod, self.root
        # -1 * n^{-1} (mod `mod`); when n divides mod - 1 this is exactly the
        # integer (mod - 1) / n, so h below is g ** ((mod - 1) / n).
        tmp = (mod - 1) * modinv(n, mod) % mod
        h = pow(g, tmp, mod)  # n-th root of unity derived from g
        if sign < 0:
            h = modinv(h, mod)

        # bit_reverse permutes the list in place and returns the same object.
        a = bit_reverse(a)

        # Iterative butterflies: m is the current half-block size.
        m = 1
        while m < n:
            m2 = m << 1
            _base = pow(h, n // m2, mod)  # root of unity of order m2
            _w = 1
            for x in range(m):
                for s in range(x, n, m2):
                    u = a[s]
                    d = a[s + m] * _w % mod
                    a[s] = (u + d) % mod
                    a[s + m] = (u - d) % mod
                _w = _w * _base % mod
            m <<= 1
        return a

    def ntt(self, a):
        """Forward transform of ``a`` (mutates and returns ``a``)."""
        return self._ntt(a, 1)

    def intt(self, a):
        """Inverse transform of ``a`` including the 1/n scaling (mutates and returns ``a``)."""
        mod = self.mod
        n = len(a)
        n_inv = modinv(n, mod)
        a = self._ntt(a, -1)
        for i in range(n):
            a[i] = a[i] * n_inv % mod
        return a

    def convolution(self, a, b):
        """Return the convolution of ``a`` and ``b`` modulo ``self.mod``.

        The result has ``len(a) + len(b) - 1`` entries, each in ``[0, mod)``.
        ``a`` and ``b`` are not modified; padded copies are transformed.
        If either operand is empty the result is ``[]``.
        """
        # With an empty operand ret_size would be 0 or -1 and the padding
        # arithmetic below would produce bogus lengths, so bail out early.
        if not a or not b:
            return []

        mod = self.mod
        ret_size = len(a) + len(b) - 1
        # Smallest power of two that can hold the full result.
        n = 1 << (ret_size - 1).bit_length()
        _a = a + [0] * (n - len(a))
        _b = b + [0] * (n - len(b))
        _a = self.ntt(_a)
        _b = self.ntt(_b)
        _a = [x * y % mod for x, y in zip(_a, _b)]
        _a = self.intt(_a)
        _a = _a[:ret_size]
        return _a


def convolution_ntt(a, b, mod):
    """Convolve ``a`` and ``b`` modulo an arbitrary ``mod``.

    The convolution is computed under three fixed primes (each with 3 passed
    as the root) and the per-index residues are merged with ``_garner``,
    which reduces the reconstructed value modulo ``mod``.
    """
    lhs = [value % mod for value in a]
    rhs = [value % mod for value in b]

    primes = (167772161, 469762049, 1224736769)
    # One residue list per prime, all of the same length.
    residues = [NTT(prime, 3).convolution(lhs, rhs) for prime in primes]

    # The final slot is a placeholder: _garner only uses the last modulus as
    # the modulus of the answer and never reads the matching residue.
    moduli = primes + (mod,)
    return [_garner([r1, r2, r3, 0], moduli) for r1, r2, r3 in zip(*residues)]



def solve():
    """Read the input, expand a polynomial product and print queried coefficients.

    Reads from stdin a line ``n q``, a line with the values ``a``, and a line
    with the query indices. Builds the product over all ``i`` of
    ``(a[i] - 1) + x`` modulo 998244353 by divide and conquer, then prints,
    one per line, the coefficient at each queried index.
    """
    mod = 998244353
    n, _q = nm()  # the query count is implied by the length of the query line
    a = nl()
    ntt = NTT(mod, 5)

    def recurr(l, r):
        """Return the coefficient list of the product over indices [l, r)."""
        if l + 1 == r:
            # Reduce here too: with n == 1 this leaf is printed directly.
            return [(a[l] - 1) % mod, 1]
        mid = (l + r) // 2
        # convolution already returns every coefficient reduced into
        # [0, mod), so no further reduction pass is needed.
        return ntt.convolution(recurr(l, mid), recurr(mid, r))

    # An empty product is the constant polynomial 1; recurr(0, 0) would
    # otherwise never reach its base case.
    p = recurr(0, n) if n > 0 else [1]
    b = nl()
    answers = [str(p[x]) for x in b]
    if answers:
        # One write instead of one print call per query.
        print("\n".join(answers))

    return

# Script entry point: run the solver as soon as the file is executed.
solve()
0