結果
問題 | No.3044 よくあるカエルさん |
ユーザー | |
提出日時 | 2025-02-28 22:19:35 |
言語 | PyPy3 (7.3.15) |
結果 | RE |
実行時間 | - |
コード長 | 16,961 bytes |
コンパイル時間 | 267 ms |
コンパイル使用メモリ | 82,224 KB |
実行使用メモリ | 77,976 KB |
最終ジャッジ日時 | 2025-02-28 22:19:38 |
合計ジャッジ時間 | 2,958 ms |
ジャッジサーバーID (参考情報) | judge2 / judge1 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | AC * 1 |
other | WA * 3 RE * 17 |
ソースコード
# Competitive-programming template: input helpers, constants, an NTT-based
# convolution, a formal-power-series (FPS) library, and the solution for
# yukicoder No.3044 at the bottom.
import sys
from bisect import bisect_left, bisect_right
from collections import defaultdict
from heapq import heappop, heappush
from math import log2

# ------------------------------------------------------------------ input --
input = sys.stdin.readline
II = lambda : int(input())
MI = lambda : map(int, input().split())
LI = lambda : [int(a) for a in input().split()]
SI = lambda : input().rstrip()
LLI = lambda n : [[int(a) for a in input().split()] for _ in range(n)]
LSI = lambda n : [input().rstrip() for _ in range(n)]
MI_1 = lambda : map(lambda x:int(x)-1, input().split())
LI_1 = lambda : [int(a)-1 for a in input().split()]


def graph(n:int, m:int, dir:bool=False, index:int=-1) -> list[set[int]]:
    """Read m lines "a b" and return adjacency sets for vertices 0..n+index.

    Every vertex id is shifted by ``index`` (the default -1 maps 1-based input
    to 0-based ids).  Each edge is also added in reverse unless ``dir`` is true.
    """
    edge = [set() for _ in range(n + 1 + index)]
    for _ in range(m):
        a, b = map(int, input().split())
        a += index
        b += index
        edge[a].add(b)
        if not dir:
            edge[b].add(a)
    return edge


def graph_w(n:int, m:int, dir:bool=False, index:int=-1) -> list[set[tuple]]:
    """Weighted variant of ``graph``: reads "a b c" and stores (b, c) pairs."""
    edge = [set() for _ in range(n + 1 + index)]
    for _ in range(m):
        a, b, c = map(int, input().split())
        a += index
        b += index
        edge[a].add((b, c))
        if not dir:
            edge[b].add((a, c))
    return edge


mod = 998244353
inf = 1001001001001001001
ordalp = lambda s : ord(s)-65 if s.isupper() else ord(s)-97
ordallalp = lambda s : ord(s)-39 if s.isupper() else ord(s)-97
yes = lambda : print("Yes")
no = lambda : print("No")
yn = lambda flag : print("Yes" if flag else "No")


def acc(a:list[int]):
    """Return prefix sums: sa[i] == a[0] + ... + a[i-1], with sa[0] == 0."""
    sa = [0] * (len(a) + 1)
    for i in range(len(a)):
        sa[i + 1] = a[i] + sa[i]
    return sa


# Print -1 for values at (or close to) ``inf``.
prinf = lambda ans : print(ans if ans < 1000001001001001001 else -1)
alplow = "abcdefghijklmnopqrstuvwxyz"
alpup = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
alpall = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
URDL = {'U':(-1,0), 'R':(0,1), 'D':(1,0), 'L':(0,-1)}
DIR_4 = [[-1,0],[0,1],[1,0],[0,-1]]
DIR_8 = [[-1,0],[-1,1],[0,1],[1,1],[1,0],[1,-1],[0,-1],[-1,-1]]
DIR_BISHOP = [[-1,1],[1,1],[1,-1],[-1,-1]]
prime60 = [2,3,5,7,11,13,17,19,23,29,31,37,41,43,47,53,59]
# Lift the int<->str digit limit where the interpreter has one.
if hasattr(sys, "set_int_max_str_digits"):
    sys.set_int_max_str_digits(0)
# sys.setrecursionlimit(10**6)
# import pypyjit
# pypyjit.set_param('max_unroll_recursion=-1')

DD = defaultdict
BSL = bisect_left
BSR = bisect_right

# -------------------------------------------------------------------- NTT --
MOD = 998244353
_IMAG = 911660635
_IIMAG = 86583718
_rate2 = (0, 911660635, 509520358, 369330050, 332049552, 983190778, 123842337, 238493703, 975955924, 603855026, 856644456, 131300601, 842657263, 730768835, 942482514, 806263778, 151565301, 510815449, 503497456, 743006876, 741047443, 56250497, 867605899, 0)
_irate2 = (0, 86583718, 372528824, 373294451, 645684063, 112220581, 692852209, 155456985, 797128860, 90816748, 860285882, 927414960, 354738543, 109331171, 293255632, 535113200, 308540755, 121186627, 608385704, 438932459, 359477183, 824071951, 103369235, 0)
_rate3 = (0, 372528824, 337190230, 454590761, 816400692, 578227951, 180142363, 83780245, 6597683, 70046822, 623238099, 183021267, 402682409, 631680428, 344509872, 689220186, 365017329, 774342554, 729444058, 102986190, 128751033, 395565204, 0)
_irate3 = (0, 509520358, 929031873, 170256584, 839780419, 282974284, 395914482, 444904435, 72135471, 638914820, 66769500, 771127074, 985925487, 262319669, 262341272, 625870173, 768022760, 859816005, 914661783, 430819711, 272774365, 530924681, 0)


def _fft(a):
    """In-place forward NTT of ``a`` modulo MOD.

    Runs radix-4 passes plus at most one radix-2 pass.  len(a) is expected to
    be a power of two (multiply/pow2 pad to one).  Pair with ``_ifft``.
    """
    n = len(a)
    h = (n - 1).bit_length()
    le = 0
    for le in range(0, h - 1, 2):
        p = 1 << (h - le - 2)
        rot = 1
        for s in range(1 << le):
            rot2 = rot * rot % MOD
            rot3 = rot2 * rot % MOD
            offset = s << (h - le)
            for i in range(p):
                a0 = a[i + offset]
                a1 = a[i + offset + p] * rot
                a2 = a[i + offset + p * 2] * rot2
                a3 = a[i + offset + p * 3] * rot3
                a1na3imag = (a1 - a3) % MOD * _IMAG
                a[i + offset] = (a0 + a2 + a1 + a3) % MOD
                a[i + offset + p] = (a0 + a2 - a1 - a3) % MOD
                a[i + offset + p * 2] = (a0 - a2 + a1na3imag) % MOD
                a[i + offset + p * 3] = (a0 - a2 - a1na3imag) % MOD
            rot = rot * _rate3[(~s & -~s).bit_length()] % MOD
    # One level is left over when h is odd; le has the same parity as h - 2k.
    if (h - le) & 1:
        rot = 1
        for s in range(1 << (h - 1)):
            offset = s << 1
            l = a[offset]
            r = a[offset + 1] * rot
            a[offset] = (l + r) % MOD
            a[offset + 1] = (l - r) % MOD
            rot = rot * _rate2[(~s & -~s).bit_length()] % MOD


def _ifft(a):
    """In-place inverse of ``_fft`` WITHOUT the final 1/len(a) scaling."""
    n = len(a)
    h = (n - 1).bit_length()
    le = h
    for le in range(h, 1, -2):
        p = 1 << (h - le)
        irot = 1
        for s in range(1 << (le - 2)):
            irot2 = irot * irot % MOD
            irot3 = irot2 * irot % MOD
            offset = s << (h - le + 2)
            for i in range(p):
                a0 = a[i + offset]
                a1 = a[i + offset + p]
                a2 = a[i + offset + p * 2]
                a3 = a[i + offset + p * 3]
                a2na3iimag = (a2 - a3) * _IIMAG % MOD
                a[i + offset] = (a0 + a1 + a2 + a3) % MOD
                a[i + offset + p] = (a0 - a1 + a2na3iimag) * irot % MOD
                a[i + offset + p * 2] = (a0 + a1 - a2 - a3) * irot2 % MOD
                a[i + offset + p * 3] = (a0 - a1 - a2na3iimag) * irot3 % MOD
            irot = irot * _irate3[(~s & -~s).bit_length()] % MOD
    if le & 1:
        p = 1 << (h - 1)
        for i in range(p):
            l = a[i]
            r = a[i + p]
            a[i] = l + r if l + r < MOD else l + r - MOD
            a[i + p] = l - r if l - r >= 0 else l - r + MOD


def ntt(a):
    """In-place forward NTT (no-op for length <= 1)."""
    if len(a) <= 1:
        return
    _fft(a)


def intt(a):
    """In-place inverse NTT including the 1/len(a) scaling."""
    if len(a) <= 1:
        return
    _ifft(a)
    iv = pow(len(a), MOD - 2, MOD)
    for i, x in enumerate(a):
        a[i] = x * iv % MOD


def multiply(s: list, t: list) -> list:
    """Return the convolution of s and t modulo MOD.

    Uses the schoolbook product when either input has at most 60 terms and a
    zero-padded NTT otherwise.
    """
    n, m = len(s), len(t)
    l = n + m - 1
    if min(n, m) <= 60:
        a = [0] * l
        for i, x in enumerate(s):
            for j, y in enumerate(t):
                a[i + j] += x * y
        return [x % MOD for x in a]
    z = 1 << (l - 1).bit_length()
    a = s + [0] * (z - n)
    b = t + [0] * (z - m)
    _fft(a)
    _fft(b)
    for i, x in enumerate(b):
        a[i] = a[i] * x % MOD
    _ifft(a)
    a[l:] = []
    iz = pow(z, MOD - 2, MOD)
    return [x * iz % MOD for x in a]


def pow2(s: list) -> list:
    """Return the convolution of s with itself modulo MOD."""
    n = len(s)
    l = (n << 1) - 1
    if n <= 60:
        a = [0] * l
        for i, x in enumerate(s):
            for j, y in enumerate(s):
                a[i + j] += x * y
        return [x % MOD for x in a]
    z = 1 << (l - 1).bit_length()
    a = s + [0] * (z - n)
    _fft(a)
    for i, x in enumerate(a):
        a[i] = x * x % MOD
    _ifft(a)
    a[l:] = []
    iz = pow(z, MOD - 2, MOD)
    return [x * iz % MOD for x in a]


def ntt_doubling(a: list) -> None:
    """Extend the NTT values ``a`` (length M) in place to length 2M.

    The appended half is computed by inverse-transforming, twisting by powers
    of a primitive 2M-th root (from generator 3), and transforming back.
    """
    M = len(a)
    b = a[:]
    intt(b)
    r = 1
    zeta = pow(3, (MOD - 1) // (M << 1), MOD)
    for i, x in enumerate(b):
        b[i] = x * r % MOD
        r = r * zeta % MOD
    ntt(b)
    a += b


def mod_sqrt(a: int, p: int):
    """Return x with x*x == a (mod p), or -1 if no such x exists.

    Tonelli-Shanks; p is presumably an odd prime (Euler's criterion is used).
    """
    if a < 2:
        return a
    if pow(a, (p - 1) >> 1, p) != 1:
        return -1
    b = 1
    while pow(b, (p - 1) >> 1, p) == 1:
        b += 1
    m = p - 1
    e = 0
    while not m & 1:
        m >>= 1
        e += 1
    x = pow(a, (m - 1) >> 1, p)
    y = (a * x % p) * x % p
    x = a * x % p
    z = pow(b, m, p)
    while y != 1:
        j = 0
        t = y
        while t != 1:
            j += 1
            t = t * t % p
        z = pow(z, 1 << (e - j - 1), p)
        x = x * z % p
        z = z * z % p
        y = y * z % p
        e = j
    return x


# ------------------------------------------------- formal power series --
# https://nyaannyaan.github.io/library/fps/formal-power-series.hpp
def fps_add(a: list, b: list) -> list:
    """Coefficient-wise a + b modulo MOD (result has the longer length)."""
    if len(a) < len(b):
        res = b[::]
        for i, x in enumerate(a):
            res[i] += x
    else:
        res = a[::]
        for i, x in enumerate(b):
            res[i] += x
    return [x % MOD for x in res]


def fps_add_scalar(a: list, k: int) -> list:
    """Return a copy of a with k added to the constant term."""
    res = a[:]
    res[0] = (res[0] + k) % MOD
    return res


def fps_sub(a: list, b: list) -> list:
    """Coefficient-wise a - b modulo MOD (result has the longer length)."""
    if len(a) < len(b):
        res = b[::]
        for i, x in enumerate(a):
            res[i] -= x
        res = fps_neg(res)
    else:
        res = a[::]
        for i, x in enumerate(b):
            res[i] -= x
    return [x % MOD for x in res]


def fps_sub_scalar(a: list, k: int) -> list:
    """Return a copy of a with k subtracted from the constant term."""
    return fps_add_scalar(a, -k)


def fps_neg(a: list) -> list:
    """Coefficient-wise negation modulo MOD."""
    return [MOD - x if x else 0 for x in a]


def fps_mul_scalar(a: list, k: int) -> list:
    """Multiply every coefficient by k modulo MOD."""
    return [x * k % MOD for x in a]


def fps_matmul(a: list, b: list) -> list:
    """Element-wise product a[i] * b[i] modulo MOD (not verified)."""
    return [x * b[i] % MOD for i, x in enumerate(a)]


def fps_div(a: list, b: list) -> list:
    """Polynomial quotient of a by b (empty when deg a < deg b)."""
    if len(a) < len(b):
        return []
    n = len(a) - len(b) + 1
    cnt = 0
    if len(b) > 64:
        return multiply(a[::-1][:n], fps_inv(b[::-1], n))[:n][::-1]
    f, g = a[::], b[::]
    while g and not g[-1]:
        g.pop()
        cnt += 1
    coef = pow(g[-1], MOD - 2, MOD)
    g = fps_mul_scalar(g, coef)
    deg = len(f) - len(g) + 1
    gs = len(g)
    quo = [0] * deg
    for i in range(deg)[::-1]:
        quo[i] = x = f[i + gs - 1] % MOD
        for j, y in enumerate(g):
            f[i + j] -= x * y
    return fps_mul_scalar(quo, coef) + [0] * cnt


def fps_mod(a: list, b: list) -> list:
    """Polynomial remainder of a by b, trailing zeros removed."""
    res = fps_sub(a, multiply(fps_div(a, b), b))
    while res and not res[-1]:
        res.pop()
    return res


def fps_divmod(a: list, b: list):
    """Return (quotient, remainder) of a by b; remainder has no trailing zeros."""
    q = fps_div(a, b)
    r = fps_sub(a, multiply(q, b))
    while r and not r[-1]:
        r.pop()
    return q, r


def fps_eval(a: list, x: int) -> int:
    """Evaluate the polynomial a at x modulo MOD."""
    r = 0
    w = 1
    for v in a:
        r += w * v % MOD
        w = w * x % MOD
    return r % MOD


def fps_inv(a: list, deg: int=-1) -> list:
    """First ``deg`` coefficients of 1/a (Newton iteration); requires a[0] != 0."""
    if deg == -1:
        deg = len(a)
    res = [0] * deg
    res[0] = pow(a[0], MOD - 2, MOD)
    d = 1
    while d < deg:
        f = [0] * (d << 1)
        tmp = min(len(a), d << 1)
        f[:tmp] = a[:tmp]
        g = [0] * (d << 1)
        g[:d] = res[:d]
        ntt(f)
        ntt(g)
        for i, x in enumerate(g):
            f[i] = f[i] * x % MOD
        intt(f)
        f[:d] = [0] * d
        ntt(f)
        for i, x in enumerate(g):
            f[i] = f[i] * x % MOD
        intt(f)
        for j in range(d, min(d << 1, deg)):
            if f[j]:
                res[j] = MOD - f[j]
            else:
                res[j] = 0
        d <<= 1
    return res


def fps_pow(a: list, k: int, deg=-1) -> list:
    """First ``deg`` coefficients of a**k (via log/exp after factoring x^i)."""
    n = len(a)
    if deg == -1:
        deg = n
    if k == 0:
        if not deg:
            return []
        ret = [0] * deg
        ret[0] = 1
        return ret
    for i, x in enumerate(a):
        if x:
            rev = pow(x, MOD - 2, MOD)
            ret = fps_mul_scalar(fps_exp(fps_mul_scalar(fps_log(fps_mul_scalar(a, rev)[i:], deg), k), deg), pow(x, k, MOD))
            ret[:0] = [0] * (i * k)
            if len(ret) < deg:
                ret[len(ret):] = [0] * (deg - len(ret))
                return ret
            return ret[:deg]
        if (i + 1) * k >= deg:
            break
    return [0] * deg


def fps_exp(a: list, deg=-1) -> list:
    """First ``deg`` coefficients of exp(a); expects a[0] == 0 (not checked)."""
    if deg == -1:
        deg = len(a)
    inv = [0, 1]  # modular inverses 1..n, extended on demand

    def inplace_integral(F: list) -> list:
        n = len(F)
        while len(inv) <= n:
            j, k = divmod(MOD, len(inv))
            inv.append((-inv[k] * j) % MOD)
        return [0] + [x * inv[i + 1] % MOD for i, x in enumerate(F)]

    def inplace_diff(F: list) -> list:
        return [x * i % MOD for i, x in enumerate(F) if i]

    b = [1, (a[1] if 1 < len(a) else 0)]
    c = [1]
    z1 = []
    z2 = [1, 1]
    m = 2
    while m < deg:
        y = b + [0] * m
        ntt(y)
        z1 = z2
        z = [y[i] * p % MOD for i, p in enumerate(z1)]
        intt(z)
        z[:m >> 1] = [0] * (m >> 1)
        ntt(z)
        for i, p in enumerate(z1):
            z[i] = z[i] * (-p) % MOD
        intt(z)
        c[m >> 1:] = z[m >> 1:]
        z2 = c + [0] * m
        ntt(z2)
        tmp = min(len(a), m)
        x = a[:tmp] + [0] * (m - tmp)
        x = inplace_diff(x)
        x.append(0)
        ntt(x)
        for i, p in enumerate(x):
            x[i] = y[i] * p % MOD
        intt(x)
        for i, p in enumerate(b):
            if not i:
                continue
            x[i - 1] -= p * i % MOD
        x += [0] * m
        for i in range(m - 1):
            x[m + i], x[i] = x[i], 0
        ntt(x)
        for i, p in enumerate(z2):
            x[i] = x[i] * p % MOD
        intt(x)
        x.pop()
        x = inplace_integral(x)
        x[:m] = [0] * m
        for i in range(m, min(len(a), m << 1)):
            x[i] += a[i]
        ntt(x)
        for i, p in enumerate(y):
            x[i] = x[i] * p % MOD
        intt(x)
        b[m:] = x[m:]
        m <<= 1
    return b[:deg]


def fps_log(a: list, deg=-1) -> list:
    """First ``deg`` coefficients of log(a); expects a[0] == 1 (not checked)."""
    if deg == -1:
        deg = len(a)
    return fps_integral(multiply(fps_diff(a), fps_inv(a, deg))[:deg - 1])


def fps_integral(a: list) -> list:
    """Formal integral with zero constant term."""
    n = len(a)
    res = [0] * (n + 1)
    if n:
        res[1] = 1
    for i in range(2, n + 1):
        j, k = divmod(MOD, i)
        res[i] = (-res[k] * j) % MOD  # res[i] = 1/i via the inverse recurrence
    for i, x in enumerate(a):
        res[i + 1] = res[i + 1] * x % MOD
    return res


def fps_diff(a: list) -> list:
    """Formal derivative."""
    return [i * x % MOD for i, x in enumerate(a) if i]


def shrink(a: list) -> None:
    """Remove trailing zero coefficients in place."""
    while a and not a[-1]:
        a.pop()


class Mat:
    """2x2 matrix of polynomials, stored row-major as arr = [a00, a01, a10, a11]."""

    def __init__(self, a00: list, a01: list, a10: list, a11: list) -> None:
        self.arr = [a00, a01, a10, a11]

    def __mul__(self, r):
        """Matrix * Mat -> Mat, or Matrix * [p0, p1] -> [q0, q1] (trimmed)."""
        a00, a01, a10, a11 = self.arr
        if type(r) is Mat:
            ra00, ra01, ra10, ra11 = r.arr
            A00 = fps_add(multiply(a00, ra00), multiply(a01, ra10))
            A01 = fps_add(multiply(a00, ra01), multiply(a01, ra11))
            A10 = fps_add(multiply(a10, ra00), multiply(a11, ra10))
            A11 = fps_add(multiply(a10, ra01), multiply(a11, ra11))
            shrink(A00)
            shrink(A01)
            shrink(A10)
            shrink(A11)
            return Mat(A00, A01, A10, A11)
        b0 = fps_add(multiply(a00, r[0]), multiply(a01, r[1]))
        b1 = fps_add(multiply(a10, r[0]), multiply(a11, r[1]))
        shrink(b0)
        shrink(b1)
        return [b0, b1]

    @staticmethod
    def I():
        """Identity matrix."""
        return Mat([1], [], [], [1])


def inner_naive_gcd(m: Mat, p: list) -> None:
    """One Euclid step on p = [p0, p1] in place, folding it into m."""
    quo, rem = fps_divmod(p[0], p[1])
    b10 = fps_sub(m.arr[0], multiply(m.arr[2], quo))
    b11 = fps_sub(m.arr[1], multiply(m.arr[3], quo))
    shrink(rem)
    shrink(b10)
    shrink(b11)
    m.arr = [m.arr[2], m.arr[3], b10, b11]
    p[0], p[1] = p[1], rem


def inner_half_gcd(p: list) -> Mat:
    """Recursive half-GCD step on p = [p0, p1]; returns the accumulated transform."""
    n = len(p[0])
    m = len(p[1])
    k = (n + 1) >> 1
    if m <= k:
        return Mat.I()
    m1 = inner_half_gcd([p[0][k:], p[1][k:]])
    p = m1 * p
    if len(p[1]) <= k:
        return m1
    inner_naive_gcd(m1, p)
    if len(p[1]) <= k:
        return m1
    l = len(p[0]) - 1
    j = 2 * k - l
    p[0] = p[0][j:]
    p[1] = p[1][j:]
    return inner_half_gcd(p) * m1


def inner_poly_gcd(a: list, b: list) -> Mat:
    """Return M with M * [a, b] == [gcd-multiple, []]."""
    p = [a[::], b[::]]
    shrink(p[0])
    shrink(p[1])
    n = len(p[0])
    m = len(p[1])
    if n < m:
        mat = inner_poly_gcd(p[1], p[0])
        # The arguments were swapped, so swap the matrix COLUMNS: both rows
        # must be permuted for row 1 to still annihilate [a, b].
        mat.arr = [mat.arr[1], mat.arr[0], mat.arr[3], mat.arr[2]]
        return mat
    res = Mat.I()
    while 1:
        m1 = inner_half_gcd(p)
        p = m1 * p
        if not p[1]:
            return m1 * res
        inner_naive_gcd(m1, p)
        if not p[1]:
            return m1 * res
        res = m1 * res


def poly_gcd(a: list, b: list) -> list:
    """Monic polynomial GCD of a and b (empty list when both are zero)."""
    p = [a, b]
    m = inner_poly_gcd(a, b)
    p = m * p
    if p[0]:
        coef = pow(p[0][-1], MOD - 2, MOD)
        for i, x in enumerate(p[0]):
            p[0][i] = x * coef % MOD
    return p[0]


def poly_inv(f: list, g: list) -> list:
    """Return [1, f^-1 mod g] if gcd(f, g) is constant, else [0, []]."""
    p = [f, g]
    m = inner_poly_gcd(f, g)
    gcd = (m * p)[0]
    if len(gcd) != 1:
        return [0, []]
    x = [[1], g]
    return [1, fps_mul_scalar(fps_mod((m * x)[0], g), pow(gcd[0], MOD - 2, MOD))]


# --------------------------------------------------- linear recurrences --
def LinearRecurrence(n: int, p: list, q: list):
    """Return [x^n] P(x)/Q(x) modulo MOD using the Bostan-Mori halving step.

    p and q are coefficient lists, constant term first.  Q(0) must be nonzero
    modulo MOD.  P may have any degree, and neither input list is modified.
    """
    q = q[:]
    shrink(q)
    while n:
        # Q(-x): negate the odd-degree coefficients.
        q_neg = [(-x) % MOD if i & 1 else x for i, x in enumerate(q)]
        num = multiply(p, q_neg)  # P(x) Q(-x)
        den = multiply(q, q_neg)  # Q(x) Q(-x), an even polynomial
        # [x^n] num/den == [x^(n>>1)] (terms of num with parity n&1) / (even terms of den)
        p = num[n & 1::2]
        q = den[::2]
        n >>= 1
    if not p:
        return 0
    # [x^0] P/Q == P(0)/Q(0); Q(0) stays a power of the original Q(0).
    return p[0] * pow(q[0], MOD - 2, MOD) % MOD


def Bostan_Mori(n: int, a: list, c: list):
    """Return the n-th term of a linear recurrence.

    The sequence satisfies a_i = c[0]*a_{i-1} + c[1]*a_{i-2} + ... +
    c[k-1]*a_{i-k} for i >= k, where k = len(c).  ``a`` must hold at least
    the initial terms a_0 .. a_{k-1}.  Terms with n < len(a) are returned
    directly.
    """
    k = len(c)
    if n < len(a):
        return a[n]
    q = [1] + [(-x) % MOD for x in c]  # Q(x) = 1 - sum c[j-1] x^j
    # P = A(x) Q(x) mod x^k has k coefficients (degree <= k-1); all are needed.
    p = multiply(a, q)[:k]
    return LinearRecurrence(n, p, q)


# --------------------------------------------------------------- solution --
def solve(n: int, t: int, k: int, l: int) -> int:
    """Return a_{n-1} modulo MOD for the jump recurrence defined by (t, k, l).

    a_0 = 1 and, for i >= 1, a_i = w1*a_{i-1} + w2*a_{i-2} + wt*a_{i-t}
    (negative indices contribute 0), with w1 = (k-1)/6, w2 = (l-k)/6 and
    wt = (6-l+1)/6 modulo MOD.  Presumably a_i is the probability of landing
    exactly i cells ahead of the start (yukicoder No.3044); confirm against
    the problem statement.
    """
    inv6 = pow(6, -1, MOD)
    # jump length -> weight.  "+=" so that t == 1 or t == 2 merges with the
    # 1-/2-step weights instead of overwriting them.
    weight = defaultdict(int)
    weight[1] += (k - 1) * inv6
    weight[2] += (l - k) * inv6
    weight[t] += (6 - l + 1) * inv6
    order = max(weight)
    # coef[j - 1] is the weight of a jump of length j (Bostan_Mori convention;
    # no dummy entry for a jump of length 0).
    coef = [weight.get(j, 0) % MOD for j in range(1, order + 1)]
    a = [0] * order
    a[0] = 1
    for i in range(1, order):
        a[i] = sum(w * a[i - d] for d, w in weight.items() if d <= i) % MOD
    return Bostan_Mori(n - 1, a, coef) % MOD


def main() -> None:
    """Read "n t" and "k l" from stdin and print the answer."""
    n, t = MI()
    k, l = MI()
    print(solve(n, t, k, l))


if __name__ == "__main__":
    main()