結果
| 問題 | No.1068 #いろいろな色 / Red and Blue and more various colors (Hard) |
| コンテスト | |
| ユーザー |
|
| 提出日時 | 2020-05-29 22:49:17 |
| 言語 | PyPy3 (7.3.15) |
| 結果 | TLE |
| 実行時間 | - |
| コード長 | 3,137 bytes |
| コンパイル時間 | 167 ms |
| コンパイル使用メモリ | 82,404 KB |
| 実行使用メモリ | 270,696 KB |
| 最終ジャッジ日時 | 2024-11-06 06:56:15 |
| 合計ジャッジ時間 | 8,117 ms |
|
ジャッジサーバーID (参考情報) |
judge4 / judge3 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 3 |
| other | AC * 10 TLE * 1 -- * 18 |
ソースコード
class NumberTheroemTransform:
    """Number-theoretic transform over Z/mod with cached twiddle tables.

    The class name keeps its historical spelling so existing callers keep
    working; ``NumberTheoremTransform`` is provided below as an alias.

    Attributes:
        mod: the modulus all arithmetic is carried out in.
        pr: base used to derive roots of unity (presumably a primitive root
            of ``mod`` -- the caller is responsible for that).
        M: largest transform length the tables currently support.
        w: powers of a primitive ``M``-th root of unity, stored so that
            index ``k`` holds the power whose exponent is ``k`` bit-reversed.
        y: the modular inverses of the entries of ``w``, same layout.
    """

    def __init__(self, mod, pr):
        """Store the modulus and root base; tables start at length 2."""
        self.mod = mod
        self.pr = pr
        self.M = 2
        self.w = [1]
        self.y = [1]

    def setwy(self, M):
        """Grow the twiddle tables so transforms of length ``M`` are possible.

        Does nothing when the tables already cover ``M``.  ``M`` is expected
        to be a power of two dividing ``mod - 1``.
        """
        if M <= self.M:
            return
        mod = self.mod
        grow = (M - self.M) // 2
        self.w += [0] * grow
        self.y += [0] * grow
        self.M = M
        z = pow(self.pr, (mod - 1) // M, mod)  # root of unity of order M
        x = pow(z, mod - 2, mod)               # its inverse (Fermat)
        # Fill the power-of-two slots: M/4 gets z, M/8 gets z^2, ...
        j = M // 4
        while j:
            self.w[j] = z
            z = z * z % mod
            self.y[j] = x
            x = x * x % mod
            j //= 2
        self.y[0] = 1
        self.w[0] = 1
        # Every other slot is the product of the slots of its set bits.
        j = M // 2
        js = 2
        while js < j:
            z = self.w[js]
            x = self.y[js]
            for k2 in range(js):
                self.w[k2 + js] = self.w[k2] * z % mod
                self.y[k2 + js] = self.y[k2] * x % mod
            js *= 2

    def fft(self, a):
        """Forward transform of ``a``, in place.

        ``len(a)`` must be a power of two.  Entries are expected to be
        reduced to ``[0, mod)`` already: additions and subtractions are
        corrected with a single conditional ``mod`` adjustment.  The output
        is not in natural frequency order; it is the order ``ifft`` consumes.
        """
        mod = self.mod
        self.setwy(len(a))
        w = self.w
        u = 1
        v = len(a) >> 1
        while v:
            # Block 0 has twiddle 1, so no multiplication is needed.
            for j in range(v):
                a[j], a[j + v] = a[j] + a[j + v], a[j] - a[j + v]
                if a[j] >= mod:
                    a[j] -= mod
                if a[j + v] < 0:
                    a[j + v] += mod
            for jh in range(1, u):
                wj = w[jh]
                js = jh * v * 2
                je = js + v
                for j in range(js, je):
                    ajv = wj * a[j + v] % mod
                    a[j + v] = a[j] - ajv
                    if a[j + v] < 0:
                        a[j + v] += mod
                    a[j] = a[j] + ajv
                    if a[j] >= mod:
                        a[j] -= mod
            u *= 2
            v >>= 1

    def ifft(self, a):
        """Inverse of ``fft``, in place and without the ``1/len(a)`` factor.

        The caller has to multiply every entry by the modular inverse of
        ``len(a)`` afterwards (``multiply`` does this).
        """
        mod = self.mod
        self.setwy(len(a))
        y = self.y
        u = len(a) >> 1
        v = 1
        while u:
            # Block 0 has twiddle 1, so no multiplication is needed.
            for j in range(v):
                a[j], a[j + v] = a[j] + a[j + v], a[j] - a[j + v]
                if a[j] >= mod:
                    a[j] -= mod
                if a[j + v] < 0:
                    a[j + v] += mod
            for jh in range(1, u):
                wj = y[jh]
                js = jh * v * 2
                je = js + v
                for j in range(js, je):
                    ajv = a[j] - a[j + v]
                    if ajv < 0:
                        ajv += mod
                    a[j] = a[j] + a[j + v]
                    if a[j] >= mod:
                        a[j] -= mod
                    a[j + v] = wj * ajv % mod
            u >>= 1
            v *= 2

    def multiply(self, s, t):
        """Return the product of polynomials ``s`` and ``t`` modulo ``mod``.

        Both arguments are coefficient lists, lowest degree first.  Neither
        argument is modified and the result is always a new list with
        ``len(s) + len(t) - 1`` coefficients, each in ``[0, mod)``.  An empty
        operand yields an empty result.  Passing the same list for both
        arguments (squaring) is supported.
        """
        mod = self.mod
        sl, tl = len(s), len(t)
        if sl == 0 or tl == 0:
            return []
        # A constant operand is a plain scalar multiplication.
        if tl == 1:
            c = t[0]
            return [x * c % mod for x in s]
        if sl == 1:
            c = s[0]
            return [c * x % mod for x in t]
        L = sl + tl - 1
        if L < 8:
            # Schoolbook product is cheaper than a transform at this size.
            u = [0] * L
            for i, si in enumerate(s):
                for j, tj in enumerate(t):
                    u[i + j] += si * tj
            return [x % mod for x in u]
        M = 1 << (L - 1).bit_length()
        # Work on reduced, zero-padded copies so the inputs stay untouched
        # and the transforms see values in [0, mod).
        fa = [x % mod for x in s]
        fa += [0] * (M - sl)
        fb = [x % mod for x in t]
        fb += [0] * (M - tl)
        self.fft(fa)
        self.fft(fb)
        for i in range(M):
            fa[i] = fa[i] * fb[i] % mod
        self.ifft(fa)
        invk = pow(M, mod - 2, mod)
        return [fa[i] * invk % mod for i in range(L)]


# Correctly spelled alias; the original name remains the canonical one here.
NumberTheoremTransform = NumberTheroemTransform
import sys

ntt = NumberTheroemTransform(998244353, 3)
read = sys.stdin.buffer.read
readline = sys.stdin.buffer.readline


def main():
    """Read N, Q, the N values A and the queries B; answer every query.

    Builds the polynomial prod_i (x + A_i - 1) modulo ``ntt.mod`` and, for
    each query value q, outputs the coefficient of x^q, one per line.
    """
    mod = ntt.mod
    N, Q = map(int, readline().split())
    AB = list(map(int, read().split()))
    queries = AB[N:]

    # One linear factor per input value.  The constant term is reduced here
    # because with a single factor no multiplication happens and nothing
    # else would reduce it.
    # NOTE(review): only matters if A_i can exceed the modulus -- confirm
    # against the problem constraints.
    polys = [[(a - 1) % mod, 1] for a in AB[:N]]
    if not polys:
        polys = [[1]]  # empty product

    # Multiply neighbours level by level.  Only the current level is kept
    # alive, unlike a stored segment tree that retains every intermediate.
    while len(polys) > 1:
        merged = [
            ntt.multiply(polys[i], polys[i + 1])
            for i in range(0, len(polys) - 1, 2)
        ]
        if len(polys) & 1:
            merged.append(polys[-1])
        polys = merged
    coeffs = polys[0]

    # A single write is much cheaper than one print call per query.
    if queries:
        sys.stdout.write("\n".join(str(coeffs[q]) for q in queries))
        sys.stdout.write("\n")


if __name__ == "__main__":
    main()