結果

問題 No.2883 K-powered Sum of Fibonacci
ユーザー PNJPNJ
提出日時 2024-09-08 17:50:33
言語 PyPy3 (7.3.15)
結果
AC  
実行時間 156 ms / 3,000 ms
コード長 7,461 bytes
コンパイル時間 552 ms
コンパイル使用メモリ 82,248 KB
実行使用メモリ 77,512 KB
最終ジャッジ日時 2024-09-08 17:50:39
合計ジャッジ時間 5,804 ms
ジャッジサーバーID (参考情報) judge3 / judge1
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間 実行使用メモリ
testcase_00 AC 52 ms 63,864 KB
testcase_01 AC 54 ms 64,524 KB
testcase_02 AC 54 ms 64,612 KB
testcase_03 AC 54 ms 65,376 KB
testcase_04 AC 55 ms 64,784 KB
testcase_05 AC 125 ms 77,236 KB
testcase_06 AC 55 ms 66,384 KB
testcase_07 AC 111 ms 77,488 KB
testcase_08 AC 63 ms 67,568 KB
testcase_09 AC 60 ms 67,524 KB
testcase_10 AC 63 ms 68,776 KB
testcase_11 AC 123 ms 77,500 KB
testcase_12 AC 67 ms 71,344 KB
testcase_13 AC 61 ms 69,156 KB
testcase_14 AC 123 ms 77,368 KB
testcase_15 AC 69 ms 70,872 KB
testcase_16 AC 56 ms 64,904 KB
testcase_17 AC 82 ms 69,868 KB
testcase_18 AC 62 ms 67,756 KB
testcase_19 AC 67 ms 71,000 KB
testcase_20 AC 127 ms 77,252 KB
testcase_21 AC 128 ms 77,512 KB
testcase_22 AC 128 ms 77,372 KB
testcase_23 AC 125 ms 77,480 KB
testcase_24 AC 126 ms 77,372 KB
testcase_25 AC 128 ms 77,472 KB
testcase_26 AC 128 ms 77,348 KB
testcase_27 AC 156 ms 77,484 KB
testcase_28 AC 128 ms 77,248 KB
testcase_29 AC 125 ms 77,296 KB
testcase_30 AC 76 ms 73,656 KB
testcase_31 AC 54 ms 64,340 KB
testcase_32 AC 54 ms 64,392 KB
testcase_33 AC 91 ms 76,392 KB
testcase_34 AC 54 ms 65,184 KB
testcase_35 AC 55 ms 65,168 KB
testcase_36 AC 57 ms 66,476 KB
testcase_37 AC 59 ms 67,412 KB
testcase_38 AC 60 ms 67,864 KB
testcase_39 AC 56 ms 66,596 KB
testcase_40 AC 54 ms 64,340 KB
testcase_41 AC 55 ms 63,516 KB
testcase_42 AC 129 ms 77,228 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

# Global modulus used by the NTT, power-series and linear-algebra helpers.
mod = 998244353
# Supported NTT-friendly primes; each p satisfies p - 1 = c * 2**k with c odd.
NTT_friend = [
  120586241, 167772161, 469762049, 754974721, 880803841, 924844033,
  943718401, 998244353, 1045430273, 1051721729, 1053818881,
]
# prime -> its index in NTT_friend (and therefore in NTT_info).
NTT_dict = {prime: idx for idx, prime in enumerate(NTT_friend)}
# NTT_info[idx] = [rank2, base] for NTT_friend[idx]: rank2 is the exponent of 2
# in p - 1, and prepared_fft() installs base as root[rank2] (presumably a
# primitive 2**rank2-th root of unity modulo p).
NTT_info = [
  [20, 74066978], [25, 17], [26, 30], [24, 362], [23, 211], [21, 44009197],
  [22, 663003469], [23, 31], [20, 363], [20, 330], [20, 2789],
]

def popcount(n):
  """Return the number of set bits among the low 64 bits of ``n``.

  SWAR reduction: each round adds neighbouring bit-fields of width ``shift``,
  doubling the field width until one 64-bit count remains. Bits above
  position 63 are ignored; for negative ``n`` the low 64 bits of its
  two's-complement form are counted.
  """
  for shift, mask in (
    (1, 0x5555555555555555),
    (2, 0x3333333333333333),
    (4, 0x0f0f0f0f0f0f0f0f),
    (8, 0x00ff00ff00ff00ff),
    (16, 0x0000ffff0000ffff),
    (32, 0x00000000ffffffff),
  ):
    n = (n & mask) + ((n >> shift) & mask)
  return n

def topbit(n):
  """Return the index of the highest set bit of a positive ``n`` (-1 when n == 0)."""
  return n.bit_length() - 1

def prepared_fft(mod = 998244353):
  """Precompute the NTT twiddle tables for the NTT-friendly prime ``mod``.

  Looks up (rank2, base) for ``mod`` in NTT_info and returns six lists of
  length 30:
    root, iroot   -- root[i] is base squared (rank2 - i) times (presumably a
                     primitive 2**i-th root of unity); iroot holds inverses.
    rate2, irate2 -- running twiddle updates consumed by radix-2 passes.
    rate3, irate3 -- running twiddle updates consumed by radix-4 passes.
  Entries outside the filled range stay 0.
  """
  rank2, base = NTT_info[NTT_dict[mod]]
  root = [0] * 30
  iroot = [0] * 30
  root[rank2] = base
  iroot[rank2] = pow(base, mod - 2, mod)
  for k in reversed(range(rank2)):
    root[k] = root[k + 1] * root[k + 1] % mod
    iroot[k] = iroot[k + 1] * iroot[k + 1] % mod

  def _rates(step, count):
    # Build (forward, inverse) update tables from roots offset by `step`.
    fwd = [0] * 30
    inv = [0] * 30
    acc, iacc = 1, 1
    for k in range(count):
      fwd[k] = root[k + step] * acc % mod
      inv[k] = iroot[k + step] * iacc % mod
      acc = acc * iroot[k + step] % mod
      iacc = iacc * root[k + step] % mod
    return fwd, inv

  rate2, irate2 = _rates(2, rank2 - 1)
  rate3, irate3 = _rates(3, rank2 - 2)
  return root, iroot, rate2, irate2, rate3, irate3

root,iroot,rate2,irate2,rate3,irate3 = prepared_fft()

def ntt(a):
  """Apply a forward number-theoretic transform to ``a`` in place, modulo ``mod``.

  ``len(a)`` must be a power of two (enforced by the assert). Levels are
  processed with radix-4 butterflies, plus one radix-2 pass when an odd
  number of levels remains. Twiddle factors come from the module-level
  ``root`` / ``rate2`` / ``rate3`` tables built by ``prepared_fft()``.
  No reordering pass is performed, so the output is presumably in the
  butterflies' bit-reversed order. In this file it is only multiplied
  pointwise and then passed to ``intt()``. Returns None.
  """
  n = len(a)
  h = topbit(n)
  assert (n == 1 << h)
  # le = number of butterfly levels already processed (out of h).
  le = 0
  while le < h:
    if h - le == 1:
      # Exactly one level left: radix-2 pass.
      p = 1 << (h - le - 1)
      rot = 1
      for s in range(1 << le):
        offset = s << (h - le)
        for i in range(p):
          l = a[i + offset]
          r = a[i + offset + p] * rot % mod
          a[i + offset] = (l + r) % mod
          a[i + offset + p] = (l - r) % mod
        # -~s == s + 1, so ~s & -~s isolates the lowest zero bit of s; its
        # index selects the twiddle update for the next block.
        rot = rot * rate2[topbit(~s & -~s)] % mod
      le += 1
    else:
      # Two or more levels left: radix-4 pass covering two levels at once.
      p = 1 << (h - le - 2)
      # root[2] is root[rank2] squared rank2 - 2 times in prepared_fft,
      # presumably a primitive 4th root of unity.
      rot,imag = 1,root[2]
      for s in range(1 << le):
        rot2 = rot * rot % mod
        rot3 = rot2 * rot % mod
        offset = s << (h - le)
        for i in range(p):
          # Python ints do not overflow, so the twiddled terms stay
          # unreduced until the final "% mod" of each output.
          a0 = a[i + offset]
          a1 = a[i + offset + p] * rot
          a2 = a[i + offset + p * 2] * rot2
          a3 = a[i + offset + p * 3] * rot3
          a1na3imag = (a1 - a3) % mod * imag
          a[i + offset] = (a0 + a2 + a1 + a3) % mod
          a[i + offset + p] = (a0 + a2 - a1 - a3) % mod
          a[i + offset + p * 2] = (a0 - a2 + a1na3imag) % mod
          a[i + offset + p * 3] = (a0 - a2 - a1na3imag) % mod
        # Same lowest-zero-bit trick as above, using the radix-4 table.
        rot = rot * rate3[topbit(~s & -~s)] % mod
      le += 2

def intt(a):
  """Undo ``ntt()`` in place, including the 1/n normalization, modulo ``mod``.

  ``len(a)`` must be a power of two (enforced by the assert). The input is
  scaled by n^-1 first, which is valid because the butterflies are linear.
  The levels of ``ntt()`` are then reversed: radix-4 inverse butterflies
  while two or more levels remain, and a radix-2 one for a single remaining
  level. Inverse twiddles come from ``iroot`` / ``irate2`` / ``irate3``.
  Returns None.
  """
  n = len(a)
  h = topbit(n)
  assert (n == 1 << h)
  # n^-1 via Fermat's little theorem (mod is prime).
  coef = pow(n,mod - 2,mod)
  for i in range(n):
    a[i] = a[i] * coef % mod
  # le = number of levels still to undo.
  le = h
  while le:
    if le == 1:
      # One level left: radix-2 inverse pass.
      p = 1 << (h - le)
      irot = 1
      for s in range(1 << (le - 1)):
        offset = s << (h - le + 1)
        for i in range(p):
          l = a[i + offset]
          r = a[i + offset + p]
          a[i + offset] = (l + r) % mod
          a[i + offset + p] = (l - r) * irot % mod
        # -~s == s + 1, so ~s & -~s isolates the lowest zero bit of s.
        irot = irot * irate2[topbit(~s & -~s)] % mod
      le -= 1
    else:
      # Two or more levels left: radix-4 inverse pass.
      p = 1 << (h - le)
      irot,iimag = 1,iroot[2]
      for s in range(1 << (le - 2)):
        irot2 = irot * irot % mod
        irot3 = irot2 * irot % mod
        offset = s << (h - le + 2)
        for i in range(p):
          a0 = a[i + offset]
          a1 = a[i + offset + p]
          a2 = a[i + offset + p * 2]
          a3 = a[i + offset + p * 3]
          a2na3iimag = (a2 - a3) * iimag % mod
          a[i + offset] = (a0 + a1 + a2 + a3) % mod
          a[i + offset + p] = (a0 - a1 + a2na3iimag) * irot % mod
          a[i + offset + p * 2] = (a0 + a1 - a2 - a3) * irot2 % mod
          a[i + offset + p * 3] = (a0 - a1 - a2na3iimag) * irot3 % mod
        irot *= irate3[topbit(~s & -~s)]
        irot %= mod
      le -= 2

def convolute_naive(a,b):
  """Multiply coefficient lists a and b (lowest degree first) modulo ``mod``.

  Schoolbook O(len(a) * len(b)) method; inputs are not modified. The result
  has len(a) + len(b) - 1 entries (an empty list when that is not positive).
  """
  out = [0] * (len(a) + len(b) - 1)
  for i, x in enumerate(a):
    for j, y in enumerate(b):
      out[i + j] = (out[i + j] + x * y % mod) % mod
  return out

def convolute(a,b):
  """Return the product of polynomials a and b, coefficients modulo ``mod``.

  Coefficient lists are lowest degree first, and neither input is modified.
  The result has len(a) + len(b) - 1 terms, or is [] if either input is
  empty. When the shorter input has at most 60 terms the schoolbook method
  is used. Otherwise an NTT of the next power-of-two size is used.
  """
  n = len(a)
  m = len(b)
  if n == 0 or m == 0:
    # Multiplying by the empty (zero) polynomial yields the empty polynomial;
    # without this guard the naive path returns len(a)+len(b)-1 spurious zeros.
    return []
  if min(n,m) <= 60:
    # convolute_naive does not mutate its arguments, so no copies are needed.
    return convolute_naive(a,b)
  le = 1
  while le < n + m - 1:
    le *= 2
  # Zero-padded copies, because ntt/intt transform in place.
  s = a + [0] * (le - n)
  t = b + [0] * (le - m)
  ntt(s)
  ntt(t)
  for i in range(le):
    s[i] = s[i] * t[i] % mod
  intt(s)
  return s[:n + m - 1]

def fps_inv(f,deg = -1):
  """Return the first ``deg`` coefficients of the power series 1/f modulo ``mod``.

  f is a coefficient list (lowest degree first) with f[0] != 0. ``deg``
  defaults to len(f) (sentinel -1); a non-positive ``deg`` yields [].
  Newton iteration doubles the known precision each round:
  g <- g - g * (f * g - 1).
  """
  assert (f[0] != 0)
  if deg == -1:
    deg = len(f)
  if deg <= 0:
    # Nothing to compute; previously res[0] raised IndexError here.
    return []
  res = [0] * deg
  res[0] = pow(f[0],mod-2,mod)
  d = 1
  while d < deg:
    size = d << 1
    a = [0] * size
    take = min(len(f),size)
    a[:take] = f[:take]
    b = [0] * size
    b[:d] = res[:d]
    ntt(a)
    ntt(b)
    for i in range(size):
      a[i] = a[i] * b[i] % mod
    intt(a)
    # Low d terms of f*g are exactly [1, 0, ...]; clear them, leaving f*g - 1
    # (cyclic wrap-around only lands in this cleared region).
    a[:d] = [0] * d
    ntt(a)
    for i in range(size):
      a[i] = a[i] * b[i] % mod
    intt(a)
    # New coefficients are the negation of g * (f*g - 1).
    for j in range(d,min(size,deg)):
      res[j] = -a[j] % mod
    d = size
  return res

def fps_div(f,g):
  """Divide polynomial f by g modulo ``mod``; return (q, r) with f == g*q + r.

  Coefficient lists are lowest degree first, and neither input is modified.
  When len(f) >= len(g), g's leading coefficient g[-1] must be nonzero,
  since fps_inv asserts on the reversed g. In that case q has
  len(f) - len(g) + 1 terms; otherwise q is []. The remainder r is always a
  new list with trailing zeros removed, so a zero remainder is [].
  """
  n,m = len(f),len(g)
  if n < m:
    # Quotient is zero and the remainder is f itself. Return a trimmed copy
    # so r is normalized like the general case and does not alias f.
    r = f[:]
    while r and r[-1] == 0:
      r.pop()
    return [],r
  # Reversing turns division into a power-series problem:
  # rev(q) = rev(f) / rev(g) mod x^(n - m + 1).
  rev_q = convolute(f[::-1],fps_inv(g[::-1],n-m+1))[:n-m+1]
  q = rev_q[::-1]
  p = convolute(g,q)
  r = f[:]
  for i in range(min(len(p),len(r))):
    r[i] = (r[i] - p[i]) % mod
  while r and r[-1] == 0:
    r.pop()
  return q,r

def Bostan_Mori(N,P,Q): # P(x) / Q(x)
  """Return [x^N] P(x) / Q(x) modulo ``mod`` using the Bostan–Mori algorithm.

  P and Q are coefficient lists (lowest degree first); P must be non-empty
  and Q[0] must be invertible modulo ``mod``. Each round multiplies the
  numerator and denominator by Q(-x). The denominator becomes even, so only
  its even coefficients are kept. The numerator keeps only the coefficients
  whose parity matches N, and N is halved, giving O(log N) rounds.

  Returns 0 for negative N (there is no such coefficient).
  Raises ValueError if Q[0] is not invertible modulo mod.
  """
  assert (len(P))
  if N < 0:
    # Without this guard the loop never ends: -1 // 2 == -1.
    return 0
  n = N
  while n:
    # QQ = Q(-x): negate the odd-degree coefficients.
    QQ = [c if i % 2 == 0 else (-c) % mod for i, c in enumerate(Q)]
    UU = convolute(P,QQ)
    V = convolute(Q,QQ)
    # Q(x)Q(-x) is even; keeping its even coefficients substitutes x^2 -> x.
    Q = V[::2]
    # Keep the numerator coefficients whose parity matches n.
    P = UU[n % 2::2]
    n //= 2
  if not P:
    # The numerator collapsed to the zero polynomial.
    return 0
  # [x^0] P/Q = P(0)/Q(0); Q[0] need not be 1 for a general denominator.
  return P[0] * pow(Q[0], -1, mod) % mod

def gauss_jordan(A,b):
  """Solve the linear system A x = b over GF(mod) by Gauss–Jordan elimination.

  A is an n x n matrix given as a list of row lists. b is an n x 1 column
  given as a list of one-element lists. Both are modified in place, and
  rows may be swapped for pivoting. On return A is the identity matrix and
  b[i][0] holds x_i. Returns (A, b).

  Raises ValueError if A is singular modulo mod.
  """
  n = len(A)
  for i in range(n):
    # Partial pivoting: the diagonal entry can be 0 even for an invertible A,
    # so bring up any row at or below i with a nonzero entry in column i.
    sel = next((r for r in range(i, n) if A[r][i] % mod), None)
    if sel is None:
      raise ValueError("matrix is singular modulo mod")
    if sel != i:
      A[i], A[sel] = A[sel], A[i]
      b[i], b[sel] = b[sel], b[i]
    inv = pow(A[i][i], -1, mod)  # one modular inverse per row
    for j in range(i, n):
      A[i][j] = A[i][j] * inv % mod
    b[i][0] = b[i][0] * inv % mod
    for j in range(i + 1, n):
      factor = A[j][i]
      for k in range(i, n):
        A[j][k] = (A[j][k] - factor * A[i][k]) % mod
      b[j][0] = (b[j][0] - factor * b[i][0]) % mod
  # Back substitution must go from the last column upward. When column i is
  # cleared, row i must already be a unit vector; otherwise b[i][0] is not
  # yet x_i, and the updates to A[j][k] for k > i would be lost.
  for i in range(n - 1, -1, -1):
    for j in range(i):
      if A[j][i]:
        b[j][0] = (b[j][0] - A[j][i] * b[i][0]) % mod
        A[j][i] = 0
  return A,b

def berlekamp_massey(A):
  """Find the shortest linear recurrence of sequence ``A`` over GF(mod).

  Returns the connection polynomial as a coefficient list ``conn`` with
  conn[0] == 1. Let ``length`` be the recurrence order found. Then
  sum(conn[j] * A[i - j] for j in 0..length) == 0 for every
  i >= length. The list is not trimmed, so it may carry extra entries past
  index ``length``.
  """
  conn = [1]       # current connection polynomial
  prev = [1]       # polynomial saved at the last length change
  length = 0       # current recurrence order
  shift = 1        # steps since the last length change
  prev_disc = 1    # discrepancy recorded at the last length change
  for i, a_i in enumerate(A):
    # Discrepancy between A[i] and the value the current recurrence predicts.
    disc = a_i
    for j in range(1, length + 1):
      disc = (disc + conn[j] * A[i - j] % mod) % mod
    if not disc:
      shift += 1
      continue
    saved = list(conn)
    coef = pow(prev_disc, -1, mod) * disc % mod
    need = len(prev) + shift
    if len(conn) < need:
      conn.extend([0] * (need - len(conn)))
    # conn -= coef * x^shift * prev
    for j, coeff_b in enumerate(prev):
      conn[j + shift] = (conn[j + shift] - coef * coeff_b % mod) % mod
    if 2 * length <= i:
      prev = saved
      length, shift, prev_disc = i + 1 - length, 1, disc
    else:
      shift += 1
  return conn

def main():
  """Read "N K" from stdin and print sum_{i=1}^{N} Fib_i^K modulo mod.

  Fib_1 = Fib_2 = 1. The sequence Fib_i^K satisfies a linear recurrence of
  order at most K + 1, since its characteristic roots are phi^a * psi^(K-a).
  Berlekamp–Massey recovers that recurrence from at least 2(K + 1) terms.
  The prefix sum at index N - 1 is then read off with Bostan–Mori.
  """
  N,K = map(int,input().split())
  if N <= 0:
    # Empty sum; Bostan_Mori(-1, ...) would otherwise never terminate.
    print(0)
    return
  # Berlekamp–Massey needs 2 * order terms. Keep the original 1000 as a floor
  # so small K behaves exactly as before, but scale up for larger K.
  terms = max(1000, 2 * K + 10)
  fib = [1,1]
  F = [1,1]
  for _ in range(2,terms):
    fib.append((fib[-1] + fib[-2]) % mod)
    F.append(pow(fib[-1],K,mod))
  g = berlekamp_massey(F)
  # F(x) = f(x) / g(x), where f is F*g truncated below index len(g) - 1.
  f = convolute(F,g)[:len(g) - 1]
  # Multiplying the denominator by (1 - x) turns the series into prefix sums.
  g = convolute(g,[1,mod - 1])
  # Prefix sum at index N - 1 equals Fib_1^K + ... + Fib_N^K.
  print(Bostan_Mori(N - 1,f,g))


if __name__ == "__main__":
  main()