結果

問題 No.1066 #いろいろな色 / Red and Blue and more various colors (Easy)
ユーザー SalmonizeSalmonize
提出日時 2020-05-30 00:38:16
言語 Python3
(3.13.1 + numpy 2.2.1 + scipy 1.14.1)
結果
AC  
実行時間 792 ms / 2,000 ms
コード長 3,554 bytes
コンパイル時間 283 ms
コンパイル使用メモリ 13,184 KB
実行使用メモリ 12,544 KB
最終ジャッジ日時 2024-11-06 10:04:59
合計ジャッジ時間 8,337 ms
ジャッジサーバーID
(参考情報)
judge4 / judge1
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 3
other AC * 24
権限があれば一括ダウンロードができます

ソースコード

diff #
プレゼンテーションモードにする

import sys
readline = sys.stdin.readline
ns = lambda: readline().rstrip()
ni = lambda: int(readline().rstrip())
nm = lambda: map(int, readline().split())
nl = lambda: list(map(int, readline().split()))
def modinv(x, mod):
a, b = x, mod
u, v = 1, 0
while b:
t = a // b
a -= t * b; a, b = b, a
u -= t * v; u, v = v, u
return u % mod
def _garner(xs, mods):
M = len(xs)
coeffs = [1] * M
constants = [0] * M
for i in range(M - 1):
mod_i = mods[i]
v = (xs[i] - constants[i]) * modinv(coeffs[i], mod_i) % mod_i
for j in range(i + 1, M):
mod_j = mods[j]
constants[j] = (constants[j] + coeffs[j] * v) % mod_j
coeffs[j] = (coeffs[j] * mod_i) % mod_j
return constants[-1]
def bit_reverse(d):
n = len(d)
ns = n >> 1
nss = ns >> 1
ns1 = ns + 1
i = 0
for j in range(0, ns, 2):
if j < i:
d[i], d[j] = d[j], d[i]
d[i + ns1], d[j + ns1] = d[j + ns1], d[i + ns1]
d[i + 1], d[j + ns] = d[j + ns], d[i + 1]
k = nss
i ^= k
while k > i:
k >>= 1
i ^= k
return d
class NTT:
def __init__(self, mod, primitive_root):
self.mod = mod
self.root = primitive_root
def _ntt(self, a, sign):
n = len(a)
mod, g = self.mod, self.root
tmp = (mod - 1) * modinv(n, mod) % mod # -1/n
h = pow(g, tmp, mod) # ^n√g
if sign < 0:
h = modinv(h, mod)
a = bit_reverse(a)
m = 1
while m < n:
m2 = m << 1
_base = pow(h, n // m2, mod)
_w = 1
for x in range(m):
for s in range(x, n, m2):
u = a[s]
d = a[s + m] * _w % mod
a[s] = (u + d) % mod
a[s + m] = (u - d) % mod
_w = _w * _base % mod
m <<= 1
return a
def ntt(self, a):
return self._ntt(a, 1)
def intt(self, a):
mod = self.mod
n = len(a)
n_inv = modinv(n, mod)
a = self._ntt(a, -1)
for i in range(n):
a[i] = a[i] * n_inv % mod
return a
def convolution(self, a, b):
mod = self.mod
ret_size = len(a) + len(b) - 1
n = 1 << (ret_size - 1).bit_length()
_a = a + [0] * (n - len(a))
_b = b + [0] * (n - len(b))
_a = self.ntt(_a)
_b = self.ntt(_b)
_a = [x * y % mod for x, y in zip(_a, _b)]
_a = self.intt(_a)
_a = _a[:ret_size]
return _a
def convolution_ntt(a, b, mod):
a = [x % mod for x in a]
b = [x % mod for x in b]
mods = (167772161, 469762049, 1224736769, mod)
ntt1 = NTT(mods[0], 3)
ntt2 = NTT(mods[1], 3)
ntt3 = NTT(mods[2], 3)
x1 = ntt1.convolution(a, b)
x2 = ntt2.convolution(a, b)
x3 = ntt3.convolution(a, b)
n = len(x1)
ret = [0] * n
for i in range(n):
xs = [x1[i], x2[i], x3[i], 0]
ret[i] = _garner(xs, mods)
return ret
def solve():
mod = 998244353
n, q = nm()
a = nl()
ntt = NTT(mod, 5)
def recurr(l, r):
if l + 1 == r:
return [a[l]-1, 1]
res1 = recurr(l, (l+r)//2)
res2 = recurr((l+r)//2, r)
res = ntt.convolution(res1, res2)
for i in range(len(res)):
res[i] %= mod
return res
p = recurr(0, n)
b = nl()
for x in b:
print(p[x])
return
solve()
הההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההה
XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
0