結果

問題 No.1068 #いろいろな色 / Red and Blue and more various colors (Hard)
ユーザー NyaanNyaanNyaanNyaan
提出日時 2020-05-29 22:49:17
言語 PyPy3
(7.3.15)
結果
TLE  
実行時間 -
コード長 3,137 bytes
コンパイル時間 167 ms
コンパイル使用メモリ 82,404 KB
実行使用メモリ 270,696 KB
最終ジャッジ日時 2024-11-06 06:56:15
合計ジャッジ時間 8,117 ms
ジャッジサーバーID
(参考情報)
judge4 / judge3
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 38 ms
60,520 KB
testcase_01 AC 36 ms
52,944 KB
testcase_02 AC 39 ms
54,304 KB
testcase_03 AC 224 ms
80,252 KB
testcase_04 AC 207 ms
79,796 KB
testcase_05 AC 211 ms
79,412 KB
testcase_06 AC 191 ms
78,876 KB
testcase_07 AC 177 ms
79,012 KB
testcase_08 AC 197 ms
79,284 KB
testcase_09 AC 202 ms
79,412 KB
testcase_10 AC 166 ms
78,488 KB
testcase_11 AC 179 ms
79,264 KB
testcase_12 AC 166 ms
78,932 KB
testcase_13 TLE -
testcase_14 -- -
testcase_15 -- -
testcase_16 -- -
testcase_17 -- -
testcase_18 -- -
testcase_19 -- -
testcase_20 -- -
testcase_21 -- -
testcase_22 -- -
testcase_23 -- -
testcase_24 -- -
testcase_25 -- -
testcase_26 -- -
testcase_27 -- -
testcase_28 -- -
testcase_29 -- -
testcase_30 -- -
testcase_31 -- -
権限があれば一括ダウンロードができます

ソースコード

diff #

class NumberTheroemTransform:
  def __init__(self, mod, pr):
    self.mod = mod
    self.pr = pr
    self.M = 2
    self.w = [1]
    self.y = [1]

  def setwy(self, M):
    if M <= self.M: return
    self.w += [0] * ((M - self.M) // 2)
    self.y += [0] * ((M - self.M) // 2)
    self.M = M
    z = pow(self.pr, (self.mod - 1) // self.M, self.mod)
    x = pow(z, self.mod - 2, self.mod)
    j = M // 4
    while j:
      self.w[j] = z
      z = z * z % self.mod
      self.y[j] = x
      x = x * x % self.mod
      j //= 2
    self.y[0] = 1
    self.w[0] = 1
    j = self.M // 2
    js = 2
    while js < j:
      z = self.w[js]
      x = self.y[js]
      for k2 in range(js):
        self.w[k2 + js] = self.w[k2] * z % self.mod
        self.y[k2 + js] = self.y[k2] * x % self.mod
      js *= 2

  def fft(self, a):
    mod = self.mod
    self.setwy(len(a))
    u = 1
    v = len(a) >> 1
    while v:
      for j in range(v):
        a[j], a[j + v] = a[j] + a[j + v], a[j] - a[j + v]
        if a[j] >= mod:
          a[j] -= mod
        if a[j + v] < 0:
          a[j + v] += mod
      for jh in range(1, u):
        wj = self.w[jh]
        js = jh * v * 2
        je = js + v
        for j in range(js, je):
          ajv = wj * a[j + v] % mod
          a[j + v] = a[j] - ajv
          if a[j + v] < 0:
            a[j + v] += mod
          a[j] = a[j] + ajv
          if a[j] >= mod:
            a[j] -= mod
      u *= 2
      v >>= 1

  def ifft(self, a):
    mod = self.mod
    self.setwy(len(a))
    u = len(a) >> 1
    v = 1
    while u:
      for j in range(v):
        a[j], a[j + v] = a[j] + a[j + v], a[j] - a[j + v]
        if a[j] >= mod:
          a[j] -= mod
        if a[j + v] < 0:
          a[j + v] += mod
      for jh in range(1, u):
        wj = self.y[jh]
        js = jh * v * 2
        je = js + v
        for j in range(js, je):
          ajv = a[j] - a[j + v]
          if ajv < 0:
            ajv += mod
          a[j] = a[j] + a[j + v]
          if a[j] >= mod:
            a[j] -= mod
          a[j + v] = wj * ajv % mod
      u >>= 1
      v *= 2

  def multiply(self, s, t):
    if len(t) == 1:
      return s
    sl, tl = len(s), len(t)
    L = sl + tl - 1
    if L < 8:
      u = [0] * L
      for i in range(sl):
        for j in range(tl):
          u[i+j] += s[i] * t[j]
      for i in range(L):
        u[i] %= self.mod
      return u
    M = 2 ** (L - 1).bit_length()
    s += [0] * (M - sl)
    t += [0] * (M - tl)
    self.fft(s)
    self.fft(t)
    for i in range(M):
      s[i] = s[i] * t[i] % self.mod
    self.ifft(s)
    invk = pow(M, self.mod - 2, self.mod)
    for i in range(L):
      s[i] = s[i] * invk % self.mod
    del s[L:]
    return s


ntt = NumberTheroemTransform(998244353, 3)
import sys
read = sys.stdin.buffer.read
readline = sys.stdin.buffer.readline

N, Q = map(int, readline().split())
AB = list(map(int, read().split()))
A = AB[:N]
B = AB[N:]

M = 2 ** (N - 1).bit_length()
seg = [ [1] for i in range(2 * M)]
for i in range(N):
    seg[M + i] = [A[i] - 1, 1]

for i in range(M - 1, 0, -1):
  seg[i] = ntt.multiply(seg[2 * i], seg[2 * i + 1])

for q in B:
  print(seg[1][q])
0