結果

問題 No.1839 Concatenation Matrix
ユーザー hitonanodehitonanode
提出日時 2021-12-19 20:16:04
言語 PyPy3 (7.3.15)
結果
AC  
実行時間 2,845 ms / 3,500 ms
コード長 3,444 bytes
コンパイル時間 220 ms
コンパイル使用メモリ 82,432 KB
実行使用メモリ 162,324 KB
最終ジャッジ日時 2024-09-15 14:47:19
合計ジャッジ時間 22,076 ms
ジャッジサーバーID (参考情報) judge2 / judge3
このコードへのチャレンジ (要ログイン)

テストケース

テストケース表示
入力 結果 実行時間 実行使用メモリ
testcase_00 AC 45 ms 53,120 KB
testcase_01 AC 44 ms 52,992 KB
testcase_02 AC 45 ms 53,760 KB
testcase_03 AC 44 ms 53,120 KB
testcase_04 AC 45 ms 53,248 KB
testcase_05 AC 46 ms 53,120 KB
testcase_06 AC 45 ms 53,760 KB
testcase_07 AC 136 ms 77,184 KB
testcase_08 AC 222 ms 79,032 KB
testcase_09 AC 236 ms 79,252 KB
testcase_10 AC 803 ms 100,128 KB
testcase_11 AC 2,777 ms 162,324 KB
testcase_12 AC 2,644 ms 161,560 KB
testcase_13 AC 2,707 ms 162,128 KB
testcase_14 AC 2,710 ms 162,216 KB
testcase_15 AC 2,845 ms 162,296 KB
testcase_16 AC 1,575 ms 127,488 KB
testcase_17 AC 1,638 ms 126,736 KB
testcase_18 AC 2,033 ms 136,636 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

# https://atcoder.jp/contests/abl/submissions/17042688
import __pypy__
import heapq


class NumberTheroemTransform:
    """Polynomial multiplication modulo a prime via the number-theoretic transform.

    Twiddle-factor tables are grown lazily by ``setwy``. Entry ``w[k]`` depends
    only on ``k`` (not on the table size), so tables built for a large
    transform length remain valid for every smaller power-of-two length.
    (The class name's spelling is kept as-is because callers reference it.)

    Constructor args:
        mod: prime modulus; every transform length used must divide ``mod - 1``.
        pr: primitive root modulo ``mod``.
    """

    def __init__(self, mod, pr):
        self.mod = mod
        self.pr = pr
        self.M = 2      # largest transform length the tables currently support
        self.w = [1]    # forward twiddle factors, len(self.w) == self.M // 2
        self.y = [1]    # inverse twiddle factors, len(self.y) == self.M // 2

    def setwy(self, M):
        """Grow the twiddle tables to support transforms of length ``M``.

        ``M`` is expected to be a power of two. Does nothing when the tables
        are already at least that large.
        """
        if M <= self.M:
            return
        mod = self.mod
        self.w += [0] * ((M - self.M) // 2)
        self.y += [0] * ((M - self.M) // 2)
        self.M = M
        z = pow(self.pr, (mod - 1) // M, mod)   # primitive M-th root of unity
        x = pow(z, mod - 2, mod)                # its inverse (Fermat's little theorem)
        # Seed the power-of-two slots: w[M/4] = z, w[M/8] = z^2, ..., w[1] = z^(M/4).
        j = M // 4
        while j:
            self.w[j] = z
            z = z * z % mod
            self.y[j] = x
            x = x * x % mod
            j //= 2
        self.y[0] = 1
        self.w[0] = 1
        # Fill the remaining slots: w[js + k] = w[k] * w[js] for each power of two js.
        half = M // 2
        js = 2
        while js < half:
            z = self.w[js]
            x = self.y[js]
            for k2 in range(js):
                self.w[k2 + js] = self.w[k2] * z % mod
                self.y[k2 + js] = self.y[k2] * x % mod
            js *= 2

    def fft(self, a):
        """Forward NTT of ``a`` in place.

        ``len(a)`` must be a power of two. No bit-reversal permutation is
        applied, so the output order is only meaningful for point-wise
        products and as input to ``ifft``.
        """
        mod = self.mod
        self.setwy(len(a))
        u = 1
        v = len(a) >> 1
        while v:
            # Block 0 uses twiddle w[0] == 1, so its multiplication is skipped.
            for j in range(v):
                a[j], a[j + v] = a[j] + a[j + v], a[j] - a[j + v]
                if a[j] >= mod:
                    a[j] -= mod
                if a[j + v] < 0:
                    a[j + v] += mod
            for jh in range(1, u):
                wj = self.w[jh]
                js = jh * v * 2
                je = js + v
                for j in range(js, je):
                    ajv = wj * a[j + v] % mod
                    a[j + v] = a[j] - ajv
                    if a[j + v] < 0:
                        a[j + v] += mod
                    a[j] = a[j] + ajv
                    if a[j] >= mod:
                        a[j] -= mod
            u *= 2
            v >>= 1

    def ifft(self, a):
        """Inverse NTT of ``a`` in place, consuming ``fft``'s output order.

        The result is NOT divided by ``len(a)``; callers must apply that
        scaling themselves (as ``multiply`` does).
        """
        mod = self.mod
        self.setwy(len(a))
        u = len(a) >> 1
        v = 1
        while u:
            # Block 0 uses twiddle y[0] == 1, so its multiplication is skipped.
            for j in range(v):
                a[j], a[j + v] = a[j] + a[j + v], a[j] - a[j + v]
                if a[j] >= mod:
                    a[j] -= mod
                if a[j + v] < 0:
                    a[j + v] += mod
            for jh in range(1, u):
                wj = self.y[jh]
                js = jh * v * 2
                je = js + v
                for j in range(js, je):
                    ajv = a[j] - a[j + v]
                    if ajv < 0:
                        ajv += mod
                    a[j] = a[j] + a[j + v]
                    if a[j] >= mod:
                        a[j] -= mod
                    a[j + v] = wj * ajv % mod
            u >>= 1
            v *= 2

    def multiply(self, s, t):
        """Return the product of polynomials ``s`` and ``t`` modulo ``self.mod``.

        ``s`` and ``t`` are coefficient lists, lowest degree first. Neither is
        modified (so ``multiply(a, a)`` is safe). The result is a new list of
        length ``len(s) + len(t) - 1``, or ``[]`` if either input is empty.
        """
        mod = self.mod
        sl, tl = len(s), len(t)
        if sl == 0 or tl == 0:
            return []
        L = sl + tl - 1
        M = 1 << (L - 1).bit_length()   # smallest power of two >= L
        # Transform reduced, zero-padded copies. Padding the caller's lists in
        # place clobbered them, and padded an aliased argument twice.
        fs = [x % mod for x in s]
        fs += [0] * (M - sl)
        ft = [x % mod for x in t]
        ft += [0] * (M - tl)
        self.fft(fs)
        self.fft(ft)
        for i in range(M):
            fs[i] = fs[i] * ft[i] % mod
        self.ifft(fs)
        del fs[L:]
        invk = pow(M, mod - 2, mod)     # ifft leaves every coefficient scaled by M
        for i in range(L):
            fs[i] = fs[i] * invk % mod
        return fs

def main():
    """Read N and the sequence A from stdin and print ret[0..N-1], one per line.

    Computes P(x) = prod_{k=0}^{N-2} (1 + 10^(2^k) x) mod 998244353 with a
    product tree of NTT multiplications, convolves P with A, and adds
    coefficient i of the result into ret[(i + 1) % N].
    """
    MOD = 998244353
    ntt = NumberTheroemTransform(MOD, 3)

    N = int(input())
    A = list(map(int, input().split()))

    # Heap-ordered product tree over the N - 1 linear factors: leaves occupy
    # indices N-2 .. 2N-4, and internal node i has children 2i+1 and 2i+2.
    vs = [None] * (N - 1 + N - 2)

    p10 = 10
    for i in range(N - 2, N - 2 + N - 1):
        vs[i] = [1, p10]            # factor 1 + 10^(2^k) x
        p10 = p10 * p10 % MOD

    for i in range(N - 3, -1, -1):
        vs[i] = ntt.multiply(vs[i * 2 + 1], vs[i * 2 + 2])

    # With N == 1 there are no factors and the tree is empty (indexing it
    # raised IndexError); use the empty product 1 instead.
    # NOTE(review): assumes the same folding formula is intended for N == 1.
    poly = vs[0] if vs else [1]
    conv = ntt.multiply(poly, A)

    ret = [0] * N
    for i, v in enumerate(conv):
        k = (i + 1) % N
        ret[k] = (ret[k] + v) % MOD
    print(*ret, sep='\n')


if __name__ == "__main__":
    # Entry point: solve only when run as a script, not when imported.
    main()
0