結果

問題 No.1889 K Consecutive Ks (Hard)
ユーザー miscalcmiscalc
提出日時 2021-10-12 17:42:10
言語 Python3
(3.12.2 + numpy 1.26.4 + scipy 1.12.0)
結果
TLE  
実行時間 -
コード長 1,736 bytes
コンパイル時間 233 ms
コンパイル使用メモリ 13,056 KB
実行使用メモリ 103,596 KB
最終ジャッジ日時 2024-04-15 00:49:56
合計ジャッジ時間 9,254 ms
ジャッジサーバーID
(参考情報)
judge1 / judge5
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 499 ms
103,596 KB
testcase_01 AC 506 ms
44,460 KB
testcase_02 TLE -
testcase_03 -- -
testcase_04 -- -
testcase_05 -- -
testcase_06 -- -
testcase_07 -- -
testcase_08 -- -
testcase_09 -- -
testcase_10 -- -
testcase_11 -- -
testcase_12 -- -
testcase_13 -- -
testcase_14 -- -
testcase_15 -- -
testcase_16 -- -
testcase_17 -- -
testcase_18 -- -
testcase_19 -- -
testcase_20 -- -
testcase_21 -- -
testcase_22 -- -
testcase_23 -- -
testcase_24 -- -
権限があれば一括ダウンロードができます

ソースコード

diff #

import numpy as np

mod = 998244353

# https://maspypy.com/%E6%95%B0%E5%AD%A6%E3%83%BBnumpy-%E9%AB%98%E9%80%9F%E3%83%95%E3%83%BC%E3%83%AA%E3%82%A8%E5%A4%89%E6%8F%9Bfft%E3%81%AB%E3%82%88%E3%82%8B%E7%95%B3%E3%81%BF%E8%BE%BC%E3%81%BF#toc7
def convolution(f, g):
  fft_len = 1
  while 2 * fft_len < len(f) + len(g) - 1:
    fft_len *= 2
  fft_len *= 2

  Ff = np.fft.rfft(f, fft_len)
  Fg = np.fft.rfft(g, fft_len)
  Fh = Ff * Fg

  h = np.fft.irfft(Fh, fft_len)
  h = np.rint(h).astype(np.int64)

  return h[: len(f) + len(g) - 1]

def convolution2(f, g):
  f1, f2 = np.divmod(f, 1 << 15)
  g1, g2 = np.divmod(g, 1 << 15)

  a = convolution(f1, g1) % mod
  c = convolution(f2, g2) % mod
  b = (convolution(f1 + f2, g1 + g2) - (a + c)) % mod

  h = (a << 30) + (b << 15) + c
  return h % mod

n, m = map(int, input().split())

d = np.zeros(n + 10, np.int64)
for i in range(2, m + 1):
  k = 1
  while k * i <= n:
    d[k * i - 1] += 1
    k += 1

pw = np.zeros(n + 10, np.int64)
pw[1], pw[2] = 1, m - 2
for i in range(2, n):
  pw[i + 1] = (m - 1) * pw[i] % mod

dp1 = np.zeros(n + 10, np.int64)
dp2 = np.ones(n + 10, np.int64)
for i in range(n):
  dp2[i + 1] = (m - 1) * dp2[i] % mod

def onlineconvolution(l, r):
  if l + 1 == r:
    return
  c = (l + r) // 2
  onlineconvolution(l, c)
  dp20 = dp2[l : c]
  d0 = d[: r - l]
  r1 = convolution2(dp20, d0)
  for i in range(c, r):
    dp1[i] += r1[i - l]
    if dp1[i] >= mod:
      dp1[i] -= mod
  dp10 = dp1[l : c]
  pw0 = pw[: r - l]
  r2 = convolution2(dp10, pw0)
  for i in range(c, r):
    dp2[i] += mod - r2[i - l]
    if dp2[i] >= mod:
      dp2[i] -= mod
  onlineconvolution(c, r)

onlineconvolution(0, n + 1)
ans = pow(m, n, mod) + mod - dp2[n]
if ans >= mod:
  ans -= mod
print(ans)
0