結果
| 問題 | No.3044 よくあるカエルさん |
|---|---|
| コンテスト | |
| ユーザー | PNJ |
| 提出日時 | 2025-03-01 02:40:33 |
| 言語 | PyPy3 (7.3.15) |
| 結果 | AC |
| 実行時間 | 230 ms / 2,000 ms |
| コード長 | 13,830 bytes |
| コンパイル時間 | 360 ms |
| コンパイル使用メモリ | 82,376 KB |
| 実行使用メモリ | 101,248 KB |
| 最終ジャッジ日時 | 2025-03-01 02:40:39 |
| 合計ジャッジ時間 | 5,514 ms |
| ジャッジサーバーID (参考情報) | judge4 / judge3 |

(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 1 |
| other | AC * 20 |
ソースコード
mod = 998244353
n = 10 ** 6
# inv[a] == pow(a, -1, mod) for 1 <= a <= n; inv[0] is an unused placeholder (1).
# Writing mod = (mod // a) * a + (mod % a) and reducing modulo `mod` gives
#   a^{-1} = -(mod // a) * (mod % a)^{-1},
# and mod % a < a, so every entry only needs one that is already filled in.
inv = [1] * (n + 1)
for a in range(2, n + 1):
    inv[a] = (mod - inv[mod % a]) * (mod // a) % mod
def mod_inv(a, mod = 998244353):
    """Return the inverse of ``a`` modulo ``mod`` (0 when ``mod == 1``).

    Extended Euclidean algorithm, two remainder steps per loop, tracking only
    the Bezout coefficient of ``a``.  If ``a`` is not invertible the remainders
    reach 0 and a ZeroDivisionError is raised.
    """
    if mod == 1:
        return 0
    x, y = a % mod, mod
    # Invariant: cx * a == x and cy * a == y (mod mod).
    cx, cy = 1, 0
    while x != 1:
        cy -= (y // x) * cx
        y %= x
        if y == 1:
            # cy is non-positive here, so shift it into range.
            return cy + mod
        cx -= (x // y) * cy
        x %= y
    return cx
# fact[i] = i! and fact_inv[i] = (i!)^{-1}, both modulo `mod`, for 0 <= i <= n.
fact = [1] * (n + 1)
for i in range(1, n + 1):
    fact[i] = fact[i - 1] * i % mod
fact_inv = [1] * (n + 1)
# Invert n! once (Fermat), then walk down using (i-1)!^{-1} = i!^{-1} * i.
fact_inv[n] = pow(fact[n], mod - 2, mod)
for i in range(n, 0, -1):
    fact_inv[i - 1] = fact_inv[i] * i % mod
def binom(n, r):
    """Return C(n, r) modulo ``mod`` using the precomputed factorial tables.

    Out-of-range arguments (``r < 0``, ``r > n`` or ``n < 0``) give 0.
    ``n`` must not exceed the size of the ``fact`` tables.
    """
    if not 0 <= r <= n:
        return 0
    return fact[n] * (fact_inv[r] * fact_inv[n - r] % mod) % mod
def Garner(Rem, MOD, mod):
    """Return ``x % mod`` for the unique x in [0, prod(MOD)) with
    ``x == Rem[i] (mod MOD[i])`` for every ``i < len(MOD)``.

    Garner's algorithm.  The moduli in ``MOD`` must be pairwise coprime,
    because each step inverts the product of the earlier moduli modulo the
    next one.  Neither ``Rem`` nor ``MOD`` is modified.
    """
    # Work on copies: the target modulus `mod` is appended as one extra row,
    # and its remainder slot (0) is never read as a constraint.
    Mod = list(MOD) + [mod]
    Rems = list(Rem) + [0]
    k = len(Mod)
    # For each row j: constants[j] is the partial mixed-radix sum of x, and
    # coffs[j] is the product of the moduli used so far, both modulo Mod[j].
    coffs = [1] * k
    constants = [0] * k
    for i in range(k - 1):
        # Next mixed-radix digit, chosen so that x == Rems[i] (mod Mod[i]).
        v = (Rems[i] - constants[i]) * mod_inv(coffs[i], Mod[i]) % Mod[i]
        for j in range(i + 1, k):
            constants[j] = (constants[j] + coffs[j] * v) % Mod[j]
            coffs[j] = coffs[j] * Mod[i] % Mod[j]
    return constants[-1]
import random
def Tonelli_Shanks(a, p = 998244353):
    """Return a square root of ``a`` modulo ``p``, or -1 if none exists.

    ``p`` is assumed to be prime (Euler's criterion and the non-residue
    search rely on it).  0 and 1 are returned unchanged; primes with
    ``p % 4 == 3`` use the closed form ``a^((p+1)/4)``.  For other primes a
    quadratic non-residue is needed: 3 is used for 998244353, and one is
    sampled with ``random`` otherwise, so which of the two roots is returned
    may vary between calls.
    """
    a %= p
    if a < 2:
        return a
    # Euler's criterion: a is a quadratic residue iff a^((p-1)/2) == 1.
    if pow(a, (p - 1) // 2, p) != 1:
        return -1
    if p % 4 == 3:
        return pow(a, (p + 1) // 4, p)
    b = 1
    if p == 998244353:
        b = 3
    else:
        while pow(b, (p - 1) // 2, p) == 1:
            b = random.randint(2, p - 1)
    # Write p - 1 = q * 2^Q with q odd.
    q = p - 1
    Q = 0
    while q % 2 == 0:
        Q += 1
        q >>= 1
    x = pow(a, (q + 1) // 2, p)
    b = pow(b, q, p)
    # The inverse of a does not change between iterations; compute it once.
    a_inv = pow(a, -1, p)
    shift = 2
    while x * x % p != a:
        # error = x^2 / a; each pass halves the 2-power order it can have.
        error = a_inv * x * x % p
        if pow(error, 1 << (Q - shift), p) != 1:
            x = x * b % p
        b = b * b % p
        shift += 1
    return x
# Supported NTT-friendly primes -> (rank2, root, idx).  For each prime,
# 2^rank2 divides mod - 1; root seeds prepared_fft (presumably a primitive
# 2^rank2-th root of unity); idx is the slot used by the twiddle caches.
_NTT_PRIME_TABLE = {
    998244353: (23, 31, 0),
    120586241: (20, 74066978, 1),
    167772161: (25, 17, 2),
    469762049: (26, 30, 3),
    754974721: (24, 362, 4),
    880803841: (23, 211, 5),
    924844033: (21, 44009197, 6),
    943718401: (22, 663003469, 7),
    1045430273: (20, 363, 8),
    1051721729: (20, 330, 9),
    1053818881: (20, 2789, 10),
}


def NTT_info(mod):
    """Return ``(rank2, root, idx)`` for ``mod``, or ``(0, -1, -1)`` if unsupported."""
    return _NTT_PRIME_TABLE.get(mod, (0, -1, -1))
def prepared_fft(mod = 998244353):
    """Build the twiddle-factor tables that ``ntt`` caches for ``mod``.

    Returns ``(root, iroot, rate2, irate2, rate3, irate3)``: six lists of
    length 30, left at 0 beyond the filled prefix.

    * ``root[rank2]`` is the root from ``NTT_info(mod)``.  For smaller ``i``,
      ``root[i] = root[i + 1] ** 2``, so ``root[i]`` has order 2^i, assuming
      the table root is a primitive 2^rank2-th root of unity.  ``iroot``
      holds the modular inverses of ``root``.
    * ``rate2[i] = root[i + 2] * iroot[2] * ... * iroot[i + 1]`` and
      ``rate3[i] = root[i + 3] * iroot[3] * ... * iroot[i + 2]``.  These are
      the factors ``ntt`` multiplies its running twiddle by in radix-2 and
      radix-4 passes.  ``irate2`` and ``irate3`` are built the same way with
      roots and inverse roots swapped.

    ``mod`` is assumed prime, since inverses use Fermat's little theorem.
    """
    rank2 = NTT_info(mod)[0]
    root, iroot = [0] * 30, [0] * 30
    rate2, irate2 = [0] * 30, [0] * 30
    rate3, irate3 = [0] * 30, [0] * 30
    root[rank2] = NTT_info(mod)[1]
    iroot[rank2] = pow(root[rank2], mod - 2, mod)
    # Square repeatedly to get the roots of every smaller power-of-two order.
    for i in range(rank2 - 1, -1, -1):
        root[i] = root[i + 1] * root[i + 1] % mod
        iroot[i] = iroot[i + 1] * iroot[i + 1] % mod
    # prod / iprod accumulate the inverse / forward roots of the earlier levels.
    prod, iprod = 1, 1
    for i in range(rank2 - 1):
        rate2[i] = root[i + 2] * prod % mod
        irate2[i] = iroot[i + 2] * iprod % mod
        prod = prod * iroot[i + 2] % mod
        iprod = iprod * root[i + 2] % mod
    prod, iprod = 1, 1
    for i in range(rank2 - 2):
        rate3[i] = root[i + 3] * prod % mod
        irate3[i] = iroot[i + 3] * iprod % mod
        prod = prod * iroot[i + 3] % mod
        iprod = iprod * root[i + 3] % mod
    return root, iroot, rate2, irate2, rate3, irate3
# Per-prime caches for the tables built by prepared_fft, indexed by
# NTT_info(mod)[2] (11 supported primes).  An empty slot means "not built yet".
root, iroot, rate2, irate2, rate3, irate3 = ([[] for _ in range(11)] for _ in range(6))
def ntt(a, inverse = 0, mod = 998244353):
    """In-place number-theoretic transform of ``a`` modulo ``mod``; returns None.

    ``len(a)`` must be a power of two (asserted).  ``inverse == 0`` applies
    the forward transform.  Any other value applies the inverse transform,
    which also multiplies by ``1 / len(a)``; that factor is computed as
    ``pow(n, mod - 2, mod)``, so ``mod`` is assumed prime.  Input entries need
    not be reduced, but every entry the butterflies write is reduced modulo
    ``mod``.

    The structure (radix-4 passes, at most one radix-2 pass, and running
    twiddles from rate2/rate3) appears to be a port of the AtCoder Library
    butterfly.  In that case the forward output is in bit-reversed order;
    TODO confirm.

    NOTE(review): for a ``mod`` that NTT_info does not know, ``idx`` is -1,
    which silently selects the last cache slot (the one for 1053818881).
    """
    idx = NTT_info(mod)[2]
    # Build this prime's twiddle tables on first use and cache them.
    if len(root[idx]) == 0:
        root[idx], iroot[idx], rate2[idx], irate2[idx], rate3[idx], irate3[idx] = prepared_fft(mod)
    n = len(a)
    h = (n - 1).bit_length()
    assert (n == 1 << h)
    if inverse == 0:
        # Forward: handle levels le = 0 .. h-1, two at a time where possible.
        le = 0
        while le < h:
            if h - le == 1:
                # Only one level left: radix-2 butterflies.
                p = 1 << (h - le - 1)
                rot = 1
                for s in range(1 << le):
                    offset = s << (h - le)
                    for i in range(p):
                        l = a[i + offset]
                        r = a[i + offset + p] * rot % mod
                        a[i + offset] = (l + r) % mod
                        a[i + offset + p] = (l - r) % mod
                    # ((~s & -~s) - 1).bit_length() == number of trailing 1-bits of s.
                    rot = rot * rate2[idx][((~s & -~s) - 1).bit_length()] % mod
                le += 1
            else:
                # Two levels at once: radix-4 butterflies.  imag = root[2],
                # presumably a primitive 4th root of unity.
                p = 1 << (h - le - 2)
                rot, imag = 1, root[idx][2]
                for s in range(1 << le):
                    rot2 = rot * rot % mod
                    rot3 = rot2 * rot % mod
                    offset = s << (h - le)
                    for i in range(p):
                        a0 = a[i + offset]
                        a1 = a[i + offset + p] * rot
                        a2 = a[i + offset + p * 2] * rot2
                        a3 = a[i + offset + p * 3] * rot3
                        a1na3imag = (a1 - a3) % mod * imag
                        a[i + offset] = (a0 + a2 + a1 + a3) % mod
                        a[i + offset + p] = (a0 + a2 - a1 - a3) % mod
                        a[i + offset + p * 2] = (a0 - a2 + a1na3imag) % mod
                        a[i + offset + p * 3] = (a0 - a2 - a1na3imag) % mod
                    rot = rot * rate3[idx][((~s & -~s) - 1).bit_length()] % mod
                le += 2
    else:
        # Inverse: scale by 1/n up front (the transform is linear), then undo
        # the levels from the bottom up with the inverse twiddles.
        coef = pow(n, mod - 2, mod)
        for i in range(n):
            a[i] = a[i] * coef % mod
        le = h
        while le:
            if le == 1:
                # Radix-2 inverse butterflies for the last remaining level.
                p = 1 << (h - le)
                irot = 1
                for s in range(1 << (le - 1)):
                    offset = s << (h - le + 1)
                    for i in range(p):
                        l = a[i + offset]
                        r = a[i + offset + p]
                        a[i + offset] = (l + r) % mod
                        a[i + offset + p] = (l - r) * irot % mod
                    irot = irot * irate2[idx][((~s & -~s) - 1).bit_length()] % mod
                le -= 1
            else:
                # Radix-4 inverse butterflies over two levels.
                p = 1 << (h - le)
                irot, iimag = 1, iroot[idx][2]
                for s in range(1 << (le - 2)):
                    irot2 = irot * irot % mod
                    irot3 = irot2 * irot % mod
                    offset = s << (h - le + 2)
                    for i in range(p):
                        a0 = a[i + offset]
                        a1 = a[i + offset + p]
                        a2 = a[i + offset + p * 2]
                        a3 = a[i + offset + p * 3]
                        a2na3iimag = (a2 - a3) * iimag % mod
                        a[i + offset] = (a0 + a1 + a2 + a3) % mod
                        a[i + offset + p] = (a0 - a1 + a2na3iimag) * irot % mod
                        a[i + offset + p * 2] = (a0 + a1 - a2 - a3) * irot2 % mod
                        a[i + offset + p * 3] = (a0 - a1 - a2na3iimag) * irot3 % mod
                    irot *= irate3[idx][((~s & -~s) - 1).bit_length()]
                    irot %= mod
                le -= 2
def transposed_ntt(a, inverse = 0, mod = 998244353):
    """In-place transpose of the bit-reversed ``ntt``; returns None.

    The forward transpose runs the inverse ``ntt``, mirrors ``a[1:]``
    (``a[i] <-> a[n - i]``) and rescales by ``n``.  The inverse transpose
    mirrors first, runs the forward ``ntt`` and rescales by ``1 / n``.
    """
    size = len(a)

    def mirror():
        # Swap a[i] and a[size - i] for 1 <= i < size // 2; index 0 stays put.
        for lo in range(1, size // 2):
            hi = size - lo
            a[lo], a[hi] = a[hi], a[lo]

    if inverse != 0:
        mirror()
        ntt(a, 0, mod)
        factor = pow(size, mod - 2, mod)
    else:
        ntt(a, 1, mod)
        mirror()
        factor = size
    for i in range(size):
        a[i] = a[i] * factor % mod
def convolution_naive(a, b, mod = 998244353):
    """Schoolbook product of coefficient lists ``a`` and ``b`` modulo ``mod``.

    The result has ``len(a) + len(b) - 1`` entries, or none when that count
    is negative.  Runs in O(len(a) * len(b)).
    """
    out = [0] * (len(a) + len(b) - 1)
    for i, x in enumerate(a):
        for j, y in enumerate(b, start=i):
            out[j] = (out[j] + x * y % mod) % mod
    return out
def convolution_ntt(a, b, mod = 998244353):
    """Return the product of ``a`` and ``b`` modulo an NTT-friendly prime ``mod``.

    If either input has at most 60 coefficients, the schoolbook routine is
    used.  Otherwise both inputs are zero-padded to the next power of two
    covering the result and multiplied pointwise in the transformed domain.
    The inputs are not modified.
    """
    n, m = len(a), len(b)
    if min(n, m) <= 60:
        return convolution_naive(a[:], b[:], mod)
    out_len = n + m - 1
    # Smallest power of two >= out_len.
    size = 1 << (out_len - 1).bit_length()
    fa = a[:] + [0] * (size - n)
    fb = b[:] + [0] * (size - m)
    ntt(fa, 0, mod)
    ntt(fb, 0, mod)
    prod = [x * y % mod for x, y in zip(fa, fb)]
    ntt(prod, 1, mod)
    return prod[:out_len]
def convolution_garner(f, g, mod):
    """Return the product of ``f`` and ``g`` modulo an arbitrary ``mod``.

    The product is computed modulo several NTT-friendly primes and combined
    with CRT (``Garner``).  Three primes are used while the coefficient bound
    ``min(len(f), len(g)) * (mod - 1)**2`` (valid for inputs already reduced
    modulo ``mod``) stays below their product.  Otherwise two more primes are
    added.
    """
    primes = [167772161, 469762049, 754974721]
    bound = (mod - 1) * (mod - 1) * min(len(f), len(g))
    if bound >= 167772161 * 469762049 * 754974721:
        primes += [880803841, 998244353]
    partial = [convolution_ntt(f, g, p) for p in primes]
    # zip(*partial) yields, per coefficient, its residues modulo each prime.
    return [Garner(list(residues), primes, mod) % mod for residues in zip(*partial)]
def convolution(f, g, mod = 998244353):
    """Return the product of ``f`` and ``g`` modulo ``mod``.

    Uses a direct NTT for the primes known to NTT_info and CRT otherwise.
    """
    if NTT_info(mod)[1] != -1:
        return convolution_ntt(f, g, mod)
    return convolution_garner(f, g, mod)
def fps_inv(f, deg = -1, mod = 998244353):
    """Return the first ``deg`` coefficients of the power series 1 / f(x) mod ``mod``.

    ``deg`` defaults to ``len(f)``, and ``deg == 0`` gives ``[]``.  ``f[0]``
    must be invertible modulo ``mod``.  NTT-friendly moduli use an NTT-based
    Newton iteration.  Any other modulus uses the doubling step
    ``g <- 2g - f*g^2`` with ``convolution``, which falls back to CRT.
    """
    assert (f[0] != 0)
    if deg == -1:
        deg = len(f)
    if deg == 0:
        return []
    n = len(f)
    # ntt_prime
    if NTT_info(mod)[2] != -1:
        g = [mod_inv(f[0], mod)] + [0 for _ in range(deg - 1)]
        le = 1
        while le < deg:
            # g is 1/f mod x^le; extend it to mod x^(2*le).
            a = [0 for _ in range(2 * le)]
            b = [0 for _ in range(2 * le)]
            for i in range(min(n, 2 * le)):
                a[i] = f[i]
            for i in range(le):
                b[i] = g[i]
            ntt(a, 0, mod)
            ntt(b, 0, mod)
            for i in range(2 * le):
                a[i] = a[i] * b[i] % mod
            ntt(a, 1, mod)
            # Only coefficients le .. 2le-1 of f*g are needed (and exact);
            # the low half also contains cyclic wrap-around.
            for i in range(le):
                a[i] = 0
            ntt(a, 0, mod)
            for i in range(2 * le):
                a[i] = a[i] * b[i] % mod
            ntt(a, 1, mod)
            for j in range(le, min(deg, 2 * le)):
                g[j] = (mod - a[j]) % mod
            le *= 2
        return g
    # Not an NTT prime: doubling with the generic convolution.  `mod` must be
    # passed through, otherwise convolution would work modulo 998244353.
    g = [0 for _ in range(deg)]
    g[0] = mod_inv(f[0], mod)
    le = 1
    while le < deg:
        gg = convolution(g[:le], g[:le], mod)
        ff = f[:min(2 * le, n)]
        gg = convolution(ff, gg, mod)
        # When len(f) == 1 the product is shorter than 2 * le; the missing
        # high coefficients are zero.
        gg += [0] * (2 * le - len(gg))
        for i in range(min(deg, 2 * le)):
            g[i] = (g[i] + g[i] - gg[i]) % mod
        le *= 2
    return g[:deg]
def fps_exp(f, deg = -1, mod = 998244353):
    """Return the first ``deg`` coefficients of exp(f(x)) modulo ``mod``.

    ``deg`` defaults to ``len(f)``.  ``f`` must be non-empty with
    ``f[0] == 0``; coefficients beyond ``len(f)`` are treated as zero.
    NTT-friendly moduli use an NTT-based doubling iteration.  Any other
    modulus uses Newton's method with ``convolution``, which falls back to
    CRT.  The caller's list is not modified.
    """
    if deg == -1:
        deg = len(f)
    n = len(f)
    assert (n > 0)
    assert (f[0] == 0)
    if n < deg:
        # The NTT path reads f[1 .. le-1] without bounds checks; padding a
        # local copy with zeros keeps that safe without changing the result.
        f = list(f) + [0] * (deg - n)
        n = deg
    # ntt_prime
    if NTT_info(mod)[2] != -1:
        # g is exp(f) modulo x^le.  h and its transform q are auxiliary series
        # maintained alongside; presumably h is 1/g to half precision (TODO confirm).
        g = [1, 0]
        if len(f) > 1:
            g[1] = f[1]
        h = [1]
        q = [1, 1]
        le = 2
        while le < deg:
            y = g + [0] * le
            ntt(y, 0, mod)
            p = q[:]
            z = [y[i] * p[i] for i in range(le)]
            ntt(z, 1, mod)
            for i in range(le // 2):
                z[i] = 0
            ntt(z, 0, mod)
            for i in range(len(p)):
                z[i] = z[i] * (-p[i]) % mod
            ntt(z, 1, mod)
            for i in range(le // 2, le):
                h.append(z[i])
            q = h + [0] * le
            ntt(q, 0, mod)
            # x = f' truncated to le - 1 terms.
            x = [0 for _ in range(le)]
            for i in range(le - 1):
                x[i] = f[i + 1] * (i + 1) % mod
            ntt(x, 0, mod)
            for i in range(le):
                x[i] = x[i] * y[i] % mod
            ntt(x, 1, mod)
            for i in range(le - 1):
                x[i] = (x[i] - g[i + 1] * (i + 1)) % mod
            x += [0] * le
            for i in range(le - 1):
                x[le + i], x[i] = x[i], 0
            ntt(x, 0, mod)
            for i in range(2 * le):
                x[i] = x[i] * q[i] % mod
            ntt(x, 1, mod)
            # Integrate: shift up one place, dividing by the new exponent.
            for i in range(len(x) - 2, -1, -1):
                x[i + 1] = x[i] * mod_inv(i + 1, mod) % mod
            for i in range(le, min(n, 2 * le)):
                x[i] += f[i]
            for i in range(le):
                x[i] = 0
            ntt(x, 0, mod)
            for i in range(2 * le):
                x[i] = x[i] * y[i] % mod
            ntt(x, 1, mod)
            for i in range(le, len(x)):
                g.append(x[i])
            le *= 2
        return g[:deg]
    # Not an NTT prime: Newton's method.  Every convolution must receive
    # `mod`; its default (998244353) would give wrong coefficients here.
    log = 0
    while (1 << log) < deg:
        log += 1
    ff = [0 for _ in range(1 << log)]
    df = [0 for _ in range(1 << log)]
    for i in range(min(n, (1 << log))):
        ff[i] = f[i]
        if i > 0:
            df[i - 1] = i * f[i] % mod
    g, h = [1], [1]
    le = 1
    for _ in range(log):
        # Refresh h (Newton step h <- 2h - g*h^2) to precision le.
        p = convolution(g, h, mod)[:le]
        p = convolution(p, h, mod)[:le]
        while len(h) < le:
            h.append(0)
        for i in range(le):
            h[i] = (2 * h[i] - p[i]) % mod
        p = df[:(le - 1)]
        p = convolution(g, p, mod)
        p.append(0)
        for i in range(2 * le - 1):
            p[i] = (mod - p[i]) % mod
        for i in range(le - 1):
            p[i] = (p[i] + g[i + 1] * (i + 1)) % mod
        p = convolution(p, h, mod)[:2 * le - 1]
        for i in range(le - 1):
            p[i] = (p[i] + df[i]) % mod
        p.append(0)
        # Integrate.
        for i in range(2 * le - 2, -1, -1):
            p[i + 1] = p[i] * mod_inv(i + 1, mod) % mod
        p[0] = 0
        for i in range(2 * le):
            p[i] = (ff[i] - p[i]) % mod
        p[0] = 1
        g = convolution(g, p, mod)[:2 * le]
        le *= 2
    return g[:deg]
def fps_log(f, deg = -1, mod = 998244353):
    """Return the first ``deg`` coefficients of log(f(x)) modulo ``mod``.

    ``f[0]`` must be 1, and ``deg`` defaults to ``len(f)``.  The result is
    computed as the integral of f'(x) / f(x).
    """
    assert (f[0] == 1)
    if deg == -1:
        deg = len(f)
    n = len(f)
    df = [0 for _ in range(deg)]
    for i in range(1, min(deg + 1, n)):
        df[i - 1] = f[i] * i % mod
    f_inv = fps_inv(f, deg, mod)
    res = convolution(df, f_inv, mod)[:deg]
    # Integrate: shift up one place and divide by the new exponent.  The
    # inverse must be taken modulo `mod`, not mod_inv's default modulus.
    for i in range(deg - 2, -1, -1):
        res[i + 1] = res[i] * mod_inv(i + 1, mod) % mod
    res[0] = 0
    return res
def fps_pow(f, k, deg = -1, mod = 998244353):
    """Return the first ``deg`` coefficients of f(x)^k modulo ``mod`` (k >= 0).

    ``deg`` defaults to ``len(f)``.  The lowest non-zero term ``a * x^d`` is
    factored out, and the rest is raised via exp(k * log(...)).
    """
    if deg == -1:
        deg = len(f)
    if k == 0:
        return [1] + [0] * (deg - 1)
    n = len(f)
    # d = index of the lowest coefficient that is non-zero modulo `mod`.
    d = 0
    while d < min(deg, n):
        if f[d] % mod:
            break
        d += 1
    if d * k >= deg or d == n:
        return [0] * deg
    a = f[d]
    a_inv = mod_inv(a, mod)
    # g = f / (a * x^d), so that g[0] == 1 as fps_log requires.
    g = [0 for _ in range(deg - d * k)]
    for i in range(min(deg - d * k, n - d)):
        g[i] = f[i + d] * a_inv % mod
    g = fps_log(g, mod=mod)
    for i in range(deg - d * k):
        g[i] = g[i] * k % mod
    g = fps_exp(g, mod=mod)
    a = pow(a, k, mod)
    res = [0] * deg
    for i in range(deg - d * k):
        res[i + d * k] = g[i] * a % mod
    return res
def Bostan_Mori(N, P, Q, mod = 998244353): # P(x) / Q(x)
    """Return the coefficient of x^N in the power series P(x) / Q(x) modulo ``mod``.

    Each round multiplies the numerator and denominator by Q(-x), which makes
    the denominator even.  It then keeps the half of the numerator whose
    parity matches N, and halves N.  ``P`` and ``Q`` must be non-empty, and
    ``Q[0]`` must be invertible modulo ``mod``.  Negative ``N`` gives 0.  The
    input lists are not modified.
    """
    assert (len(P))
    assert (len(Q))
    if N < 0:
        return 0
    n = N
    while n:
        # Q(-x): negate the odd-degree coefficients.
        Q_neg = [(mod - c) % mod if i % 2 else c for i, c in enumerate(Q)]
        U = convolution(P, Q_neg, mod)
        V = convolution(Q, Q_neg, mod)
        # V = Q(x)Q(-x) is even; its even coefficients form the new denominator.
        Q = V[0::2]
        P = U[n % 2::2]
        n //= 2
    if not P:
        # The numerator ran out of terms, so the requested coefficient is zero.
        return 0
    # Divide by the constant term so denominators with Q[0] != 1 also work.
    return P[0] * pow(Q[0], -1, mod) % mod
N, T = map(int, input().split())
k, l = map(int, input().split())
k -= 1
l -= 1
# Step polynomial g(x): each of the six faces contributes (1/6) * x^step.
# After the decrements, the first k faces step 1, the next l - k faces step 2
# and the remaining 6 - l faces step T.  f(x) = 1 - g(x).
f = [0] * (T + 1)
g = [0] * (T + 1)
for step, count in ((1, k), (2, l - k), (T, 6 - l)):
    for _ in range(count):
        f[step] = (f[step] - inv[6]) % mod
        g[step] = (g[step] + inv[6]) % mod
f[0] = 1
# Coefficient of x^(N-1) in g(x) / f(x).
print(Bostan_Mori(N - 1, g, f))
PNJ