結果

問題 No.2670 Sum of Products of Interval Lengths
ユーザー suisen
提出日時 2023-10-30 01:16:24
言語 PyPy3
(7.3.15)
結果
AC  
実行時間 755 ms / 2,000 ms
コード長 6,285 bytes
コンパイル時間 511 ms
コンパイル使用メモリ 82,468 KB
実行使用メモリ 139,628 KB
最終ジャッジ日時 2024-09-28 17:43:13
合計ジャッジ時間 10,602 ms
ジャッジサーバーID
(参考情報)
judge1 / judge2
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
other AC * 17
権限があれば一括ダウンロードができます

ソースコード

diff #
プレゼンテーションモードにする

from typing import List
def bsf(x):
res = 0
while not (x & 1):
res += 1
x >>= 1
return res
P = 998244353
G = 3
rank2 = bsf(P - 1)
class NTT:
class __RootInitializer:
@staticmethod
def root():
return [pow(G, (P - 1) >> i, P) for i in range(0, rank2 + 1)]
@staticmethod
def iroot():
return [pow(pow(G, P - 2, P), (P - 1) >> i, P) for i in range(0, rank2 + 1)]
root = __RootInitializer.root()
iroot = __RootInitializer.iroot()
class __RateInitializer:
@staticmethod
def rates(root: List[int], iroot: List[int]):
rate2 = [0] * max(0, rank2 - 1)
irate2 = [0] * max(0, rank2 - 1)
prod = iprod = 1
for i in range(rank2 - 1):
rate2[i] = root[i + 2] * prod % P
irate2[i] = iroot[i + 2] * iprod % P
prod = prod * iroot[i + 2] % P
iprod = iprod * root[i + 2] % P
rate3 = [0] * max(0, rank2 - 2)
irate3 = [0] * max(0, rank2 - 2)
prod = iprod = 1
for i in range(rank2 - 2):
rate3[i] = root[i + 3] * prod % P
irate3[i] = iroot[i + 3] * iprod % P
prod = prod * iroot[i + 3] % P
iprod = iprod * root[i + 3] % P
return rate2, irate2, rate3, irate3
rate2, irate2, rate3, irate3 = __RateInitializer.rates(__RootInitializer.root(), __RootInitializer.iroot())
@staticmethod
def butterfly(a: List[int]) -> None:
n = len(a)
h = bsf(n)
l = 0
while l < h:
if h - l == 1:
p = 1 << (h - l - 1)
rot = 1
for s in range(1 << l):
offset = s << (h - l)
for i in range(p):
u = a[i + offset]
v = a[i + offset + p] * rot
a[i + offset] = (u + v) % P
a[i + offset + p] = (u - v) % P
if s + 1 != 1 << l:
rot = rot * NTT.rate2[bsf(~s)] % P
l += 1
else:
p = 1 << (h - l - 2)
rot, imag = 1, NTT.root[2]
for s in range(1 << l):
rot2 = rot * rot % P
rot3 = rot2 * rot % P
offset = s << (h - l)
for i in range(p):
a0 = a[i + offset]
a1 = a[i + offset + p] * rot
a2 = a[i + offset + 2 * p] * rot2
a3 = a[i + offset + 3 * p] * rot3
a1na3imag = (a1 - a3) % P * imag
a[i + offset] = (a0 + a2 + a1 + a3) % P
a[i + offset + 1 * p] = (a0 + a2 - a1 - a3) % P
a[i + offset + 2 * p] = (a0 - a2 + a1na3imag) % P
a[i + offset + 3 * p] = (a0 - a2 - a1na3imag) % P
if s + 1 != (1 << l):
rot = rot * NTT.rate3[bsf(~s)] % P
l += 2
@staticmethod
def butterfly_inv(a : List[int]) -> None:
n = len(a)
h = bsf(n)
l = h
while l:
if l == 1:
p = 1 << (h - l)
irot = 1
for s in range(1 << (l - 1)):
offset = s << (h - l + 1)
for i in range(p):
u = a[i + offset]
v = a[i + offset + p]
a[i + offset] = (u + v) % P
a[i + offset + p] = ((u - v) * irot) % P
if s + 1 != 1 << (l - 1):
irot = irot * NTT.irate2[bsf(~s)] % P
l -= 1
else:
p = 1 << (h - l)
irot = 1
iimag = NTT.iroot[2]
for s in range(1 << (l - 2)):
irot2 = irot * irot % P
irot3 = irot2 * irot % P
offset = s << (h - l + 2)
for i in range(p):
a0 = a[i + offset]
a1 = a[i + offset + p]
a2 = a[i + offset + 2 * p]
a3 = a[i + offset + 3 * p]
a2na3iimag = (a2 - a3) * iimag % P
a[i + offset] = (a0 + a1 + a2 + a3) % P
a[i + offset + p] = ((a0 - a1 + a2na3iimag) * irot) % P
a[i + offset + 2 * p] = ((a0 + a1 - a2 - a3) * irot2) % P
a[i + offset + 3 * p] = ((a0 - a1 - a2na3iimag) * irot3) % P
if s + 1 != 1 << (l - 2):
irot = irot * NTT.irate3[bsf(~s)] % P
l -= 2
@staticmethod
def convolution(a, b):
n = len(a)
m = len(b)
if not a or not b:
return []
if min(n, m) <= 40:
if n < m:
n, m = m, n
a, b = b, a
res = [0] * (n + m - 1)
for i in range(n):
for j in range(m):
res[i + j] += a[i] * b[j]
res[i + j] %= P
return res
z = 1 << ((n + m - 1).bit_length())
iz = pow(z, P - 2, P)
a += [0] * (z - n)
b += [0] * (z - m)
NTT.butterfly(a)
NTT.butterfly(b)
c = [a[i] * b[i] % P * iz % P for i in range(z)]
NTT.butterfly_inv(c)
return c[:n + m - 1]
def inv(f):
assert f[0]
n = len(f)
ret = [pow(f[0], P - 2, P)]
i = 1
while i < n:
tmp = NTT.convolution(ret[:], f[: i << 1])
for j in range(len(tmp)):
if j == 0:
tmp[j] = (2 - tmp[j]) % P
else:
tmp[j] = -tmp[j] % P
ret = NTT.convolution(ret, tmp)[: i << 1]
i <<= 1
return ret[:n]
n, m = map(int, input().split())
f = [0] * (n + 1)
f[1] = 1
for i in range(2, n + 1):
f[i] = (f[i - 1] - f[i - 2]) % P
for i in range(1, n + 1):
f[i] = f[i] * (max(0, m - i + 1) % P) % P
f[0] = 1
for i in range(1, n + 1):
f[i] = -f[i] % P
print(inv(f)[n] % P)
הההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההההה
XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
0