結果

| 問題 | No.1962 Not Divide |
|---|---|
| ユーザー | |
| 提出日時 | 2022-05-27 23:07:38 |
| 言語 | PyPy3 (7.3.15) |
| 結果 | AC |
| 実行時間 | 413 ms / 2,000 ms |
| コード長 | 14,357 bytes |
| コンパイル時間 | 277 ms |
| コンパイル使用メモリ | 82,488 KB |
| 実行使用メモリ | 87,656 KB |
| 最終ジャッジ日時 | 2024-09-20 16:19:12 |
| 合計ジャッジ時間 | 5,891 ms |
| ジャッジサーバーID (参考情報) | judge2 / judge4 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | AC * 3 |
other | AC * 21 |
ソースコード
# ---------------------------------------------------------------------------
# Solution to yukicoder No.1962 "Not Divide" plus a small formal-power-series
# toolkit with coefficients modulo the prime 998244353.
#
# solve(n, m) returns [x^n] P / (Q - P), where
#     P / Q = sum_{i=1..m} P_i / Q_i,
#     P_i   = x + x^2 + ... + x^(i-1),
#     Q_i   = 1 + x + ... + x^(i-1) - x^i.
# The m fractions are added pairwise in a balanced tree (SegmentTree + merge)
# and the coefficient is extracted with Bostan-Mori (linear_recurrence).
# ---------------------------------------------------------------------------

MOD = 998244353

# Rotation-factor tables consumed by butterfly / butterfly_inv.  Entry j is
# applied when stepping past a block index s that has j trailing one bits
# (the AtCoder Library NTT scheme, from which these values appear to come);
# the trailing zeros are unused padding.
sum_e = (911660635, 509520358, 369330050, 332049552, 983190778, 123842337,
         238493703, 975955924, 603855026, 856644456, 131300601, 842657263,
         730768835, 942482514, 806263778, 151565301, 510815449, 503497456,
         743006876, 741047443, 56250497, 0, 0, 0, 0, 0, 0, 0, 0, 0)
sum_ie = (86583718, 372528824, 373294451, 645684063, 112220581, 692852209,
          155456985, 797128860, 90816748, 860285882, 927414960, 354738543,
          109331171, 293255632, 535113200, 308540755, 121186627, 608385704,
          438932459, 359477183, 824071951, 0, 0, 0, 0, 0, 0, 0, 0, 0)


def butterfly(arr):
    """Apply the forward NTT to ``arr`` in place.

    ``len(arr)`` must be a power of two (other lengths index past the end).
    Use together with butterfly_inv: transform both operands, multiply
    pointwise, transform back.
    """
    n = len(arr)
    h = (n - 1).bit_length()
    for ph in range(1, h + 1):
        w = 1 << (ph - 1)
        p = 1 << (h - ph)
        now = 1
        for s in range(w):
            offset = s << (h - ph + 1)
            for i in range(p):
                lo = arr[i + offset]
                hi = arr[i + offset + p] * now
                arr[i + offset] = (lo + hi) % MOD
                arr[i + offset + p] = (lo - hi) % MOD
            # ~s & (s + 1) isolates the lowest zero bit of s, so the index is
            # the number of trailing one bits of s.
            now = now * sum_e[(~s & -~s).bit_length() - 1] % MOD


def butterfly_inv(arr):
    """Inverse of butterfly, in place, including the division by len(arr)."""
    n = len(arr)
    h = (n - 1).bit_length()
    for ph in range(h, 0, -1):
        w = 1 << (ph - 1)
        p = 1 << (h - ph)
        inow = 1
        for s in range(w):
            offset = s << (h - ph + 1)
            for i in range(p):
                lo = arr[i + offset]
                hi = arr[i + offset + p]
                arr[i + offset] = (lo + hi) % MOD
                arr[i + offset + p] = (MOD + lo - hi) * inow % MOD
            inow = inow * sum_ie[(~s & -~s).bit_length() - 1] % MOD
    inv = pow(n, MOD - 2, MOD)
    for i in range(n):
        arr[i] = arr[i] * inv % MOD


def _convolve(f, g):
    """Return the coefficient list of f * g mod MOD; f and g are not mutated."""
    n = len(f)
    m = len(g)
    if not n or not m:
        return []
    if min(n, m) <= 100:
        # Schoolbook multiplication beats the NTT when one operand is short.
        if n < m:
            f, n, g, m = g, m, f, n
        arr = [0] * (n + m - 1)
        for i in range(n):
            for j in range(m):
                arr[i + j] += f[i] * g[j]
                arr[i + j] %= MOD
        return arr
    z = 1 << (n + m - 2).bit_length()  # smallest power of two >= n + m - 1
    f = f + [0] * (z - n)  # new lists: the callers' arrays stay untouched
    g = g + [0] * (z - m)
    butterfly(f)
    butterfly(g)
    for i in range(z):
        f[i] = f[i] * g[i] % MOD
    butterfly_inv(f)
    return f[:n + m - 1]


def build_exp(n, b):
    """Return [b^0, b^1, ..., b^n] mod MOD."""
    exp = [0] * (n + 1)
    exp[0] = 1
    for i in range(n):
        exp[i + 1] = exp[i] * b % MOD
    return exp


def build_factorial(n):
    """Return (fct, inv) with fct[i] = i! and inv[i] = (i!)^-1 mod MOD, i <= n."""
    fct = [0] * (n + 1)
    inv = [0] * (n + 1)
    fct[0] = inv[0] = 1
    for i in range(n):
        fct[i + 1] = fct[i] * (i + 1) % MOD
    inv[n] = pow(fct[n], MOD - 2, MOD)
    for i in range(n - 1, -1, -1):
        inv[i] = inv[i + 1] * (i + 1) % MOD
    return fct, inv


def sqrt_mod(n):
    """Return some r with r * r == n (mod MOD), or -1 if none exists.

    Uses Tonelli-Shanks; any integer representative of n is accepted.
    """
    n %= MOD  # e.g. MOD + 4 must behave like 4
    if n == 0:
        return 0
    if n == 1:
        return 1
    h = (MOD - 1) // 2
    if pow(n, h, MOD) != 1:  # Euler's criterion: n is a non-residue
        return -1
    q, s = MOD - 1, 0
    while not q & 1:  # write MOD - 1 = q * 2^s with q odd
        q >>= 1
        s += 1
    z = 1
    while pow(z, h, MOD) != MOD - 1:  # smallest quadratic non-residue
        z += 1
    m, c, t, r = s, pow(z, q, MOD), pow(n, q, MOD), pow(n, (q + 1) // 2, MOD)
    while t != 1:
        k = 1
        while pow(t, 1 << k, MOD) != 1:
            k += 1
        x = pow(c, pow(2, m - k - 1, MOD - 1), MOD)
        m = k
        c = x * x % MOD
        t = t * c % MOD
        r = r * x % MOD
    if r * r % MOD != n:
        return -1
    return r


class FormalPowerSeries:
    """Power series / polynomial with coefficients reduced mod MOD.

    ``arr[i]`` is the coefficient of x^i and always lies in [0, MOD).
    Reading an index past the end yields 0; writing one extends the series
    with zeros.
    """

    def __init__(self, arr=None):
        if arr is None:
            arr = []
        self.arr = [v % MOD for v in arr]

    def __len__(self):
        return len(self.arr)

    def __iter__(self):
        # Required: without it Python falls back to __getitem__(0), (1), ...
        # until IndexError, which never comes because out-of-range reads
        # return 0, so iterating (list(fps), FormalPowerSeries(fps)) hung.
        return iter(self.arr)

    def __getitem__(self, key):
        if isinstance(key, slice):
            return FormalPowerSeries(self.arr[key])
        assert key >= 0
        if key >= len(self):
            return 0
        return self.arr[key]

    def __setitem__(self, key, val):
        assert key >= 0
        if key >= len(self):
            self.arr += [0] * (key - len(self) + 1)
        self.arr[key] = val % MOD

    def __str__(self):
        return ' '.join(map(str, self.arr))

    def resize(self, sz):
        """Return a copy truncated or zero-padded to exactly sz terms."""
        assert sz >= 0
        if len(self) >= sz:
            return self[:sz]
        return FormalPowerSeries(self.arr + [0] * (sz - len(self)))

    def shrink(self):
        """Drop trailing zero coefficients in place."""
        while self.arr and not self.arr[-1]:
            self.arr.pop()

    def times(self, k):
        """Return a new series with every coefficient multiplied by scalar k."""
        return FormalPowerSeries([v * k for v in self.arr])

    def __pos__(self):
        return self

    def __neg__(self):
        return self.times(-1)

    def __add__(self, other):
        if not isinstance(other, FormalPowerSeries):
            return self + FormalPowerSeries([other])
        n = len(self)
        m = len(other)
        arr = [self.arr[i] + other.arr[i] for i in range(min(n, m))]
        arr += self.arr[m:] if n >= m else other.arr[n:]
        return FormalPowerSeries(arr)

    def __iadd__(self, other):
        if isinstance(other, FormalPowerSeries):
            n = len(self)
            m = len(other)
            for i in range(min(n, m)):
                self.arr[i] = (self.arr[i] + other.arr[i]) % MOD
            if n < m:
                self.arr += other.arr[n:]
        else:
            # __setitem__ extends an empty series instead of raising
            # IndexError, matching what __add__ does for scalars.
            self[0] = self[0] + other
        return self

    def __radd__(self, other):
        return self + other

    def __sub__(self, other):
        return self + (-other)

    def __isub__(self, other):
        self += -other
        return self

    def __rsub__(self, other):
        return (-self) + other

    def __mul__(self, other):
        if isinstance(other, FormalPowerSeries):
            return FormalPowerSeries(_convolve(self.arr, other.arr))
        return self.times(other)

    def __matmul__(self, other):
        """Product truncated to max(len(self), len(other)) terms."""
        assert isinstance(other, FormalPowerSeries)
        return (self * other).resize(max(len(self), len(other)))

    def __imul__(self, other):
        if isinstance(other, FormalPowerSeries):
            # Always updates self (also when the product is empty), so every
            # alias of this object observes the result.
            self.arr = _convolve(self.arr, other.arr)
        else:
            self.arr = [v * other % MOD for v in self.arr]
        return self

    def __rmul__(self, other):
        return self.times(other)

    def __pow__(self, k):
        """Return self ** k truncated to len(self) terms (k >= 0).

        Binary exponentiation; k == 0 returns the length-1 series [1].
        """
        # TODO: switch to exp(k * log(self)) once exp is implemented.
        if k < 0:
            # A negative k never reaches 0 under k >>= 1.
            raise ValueError("exponent must be non-negative")
        n = len(self)
        tmp = FormalPowerSeries(self.arr)
        res = FormalPowerSeries([1])
        while k:
            if k & 1:
                res *= tmp
                res = res.resize(n)
            tmp *= tmp
            tmp = tmp.resize(n)
            k >>= 1
        return res

    def square(self):
        """Return self * self (one forward NTT fewer than a general product)."""
        f = self.arr.copy()
        n = len(f)
        if not n:
            return FormalPowerSeries()
        if n <= 100:
            arr = [0] * (2 * n - 1)
            for i in range(n):
                for j in range(n):
                    arr[i + j] += f[i] * f[j]
                    arr[i + j] %= MOD
            return FormalPowerSeries(arr)
        z = 1 << (2 * n - 2).bit_length()
        f += [0] * (z - n)
        butterfly(f)
        for i in range(z):
            f[i] = f[i] * f[i] % MOD
        butterfly_inv(f)
        return FormalPowerSeries(f[:2 * n - 1])

    def __lshift__(self, key):
        """Multiply by x^key (prepend key zero coefficients)."""
        assert key >= 0
        return FormalPowerSeries([0] * key + self.arr)

    def __rshift__(self, key):
        """Drop the lowest key coefficients (divide by x^key, discarding)."""
        assert key >= 0
        return self[key:]

    def __invert__(self):
        """Return 1 / self truncated to len(self) terms (self[0] != 0)."""
        assert self[0] != 0
        n = len(self)
        res = FormalPowerSeries([pow(self[0], MOD - 2, MOD)])
        m = 1
        # Newton iteration: each round doubles the number m of correct terms.
        while m < n:
            f = [self[i] for i in range(2 * m)]
            g = [res[i] for i in range(m)] + [0] * m
            butterfly(f)
            butterfly(g)
            for i in range(2 * m):
                f[i] = f[i] * g[i] % MOD
            butterfly_inv(f)
            # Keep only indices m..2m-1 of this length-2m cyclic product
            # (the error term; cyclic wrap-around only lands below m).
            for i in range(m):
                f[i] = 0
            butterfly(f)
            for i in range(2 * m):
                f[i] = f[i] * g[i] % MOD
            butterfly_inv(f)
            for i in range(m, 2 * m):
                res[i] -= f[i]
            m <<= 1
        return res.resize(n)

    def __truediv__(self, other):
        if isinstance(other, FormalPowerSeries):
            return self * ~other
        return self * pow(other, MOD - 2, MOD)

    def __rtruediv__(self, other):
        return other * ~self

    def differentiate(self):
        """Return d/dx self, keeping the same length (top term becomes 0)."""
        n = len(self)
        arr = [0] * n
        for i in range(1, n):
            arr[i - 1] = self.arr[i] * i % MOD
        return FormalPowerSeries(arr)

    def integrate(self):
        """Return the antiderivative with constant 0, keeping the same length
        (the top input coefficient is dropped)."""
        n = len(self)
        arr = [0] * n
        for i in range(n - 1):
            arr[i + 1] = self.arr[i] * pow(i + 1, MOD - 2, MOD) % MOD
        return FormalPowerSeries(arr)

    def log(self):
        """Return log(self) truncated to len(self) terms (self[0] == 1)."""
        assert self[0] == 1
        n = len(self)
        return (self.differentiate() / self).integrate().resize(n)

    def __floordiv__(self, other):
        """Polynomial quotient of self by other.

        other's last stored coefficient must be nonzero.  A scalar divisor
        multiplies by its modular inverse instead.
        """
        if not isinstance(other, FormalPowerSeries):
            return self * pow(other, MOD - 2, MOD)
        n = len(self)
        m = len(other)
        if n < m:
            return FormalPowerSeries()
        ql = n - m + 1  # number of quotient coefficients
        if m <= 100:
            # Schoolbook long division on the reversed dividend; tmp is a
            # FormalPowerSeries, so __setitem__ already reduces mod MOD.
            arr = [0] * ql
            inv = pow(other[m - 1], MOD - 2, MOD)
            tmp = self[::-1]
            for i in range(ql):
                arr[i] = tmp[i] * inv % MOD
                for j in range(m):
                    tmp[i + j] -= other[m - j - 1] * arr[i]
            return FormalPowerSeries(arr[::-1])
        # rev(quotient) = rev(self) / rev(other) mod x^ql, which only needs
        # the top ql coefficients of self.
        rev = (self[n - ql:][::-1] * ~(other[::-1].resize(ql))).resize(ql)
        return rev[::-1]

    def __rfloordiv__(self, other):
        # For a scalar other this is the power-series product other * (1 /
        # self), truncated to len(self) terms -- not a polynomial quotient.
        return other * ~self

    def __mod__(self, other):
        """Polynomial remainder: length len(other) - 1, or a copy of self
        when self is shorter than other."""
        n = len(self)
        m = len(other)
        if n < m:
            return FormalPowerSeries(self.arr)
        return self[:m - 1] - ((self // other) * other)[:m - 1]

    def multipoint_evaluation(self, xs):
        """Return [self(x) mod MOD for x in xs] using a remainder tree."""
        n = len(xs)
        sz = 1 << (n - 1).bit_length()
        g = [FormalPowerSeries([1]) for _ in range(2 * sz)]
        for i in range(n):
            g[i + sz] = FormalPowerSeries([-xs[i], 1])
        for i in range(sz - 1, 0, -1):
            g[i] = g[2 * i] * g[2 * i + 1]
        g[1] = self % g[1]
        for i in range(2, 2 * sz):
            g[i] = g[i >> 1] % g[i]
        return [g[i + sz][0] for i in range(n)]


def polynomial_interpolation(xs, ys):
    """Return the polynomial (len(xs) coefficients) through (xs[i], ys[i]).

    xs are assumed pairwise distinct mod MOD; otherwise f'(x_i) is 0 and the
    result is meaningless.
    """
    assert len(xs) == len(ys)
    n = len(xs)
    sz = 1 << (n - 1).bit_length()
    # f: product tree of (x - xs[i]).
    f = [FormalPowerSeries([1]) for _ in range(2 * sz)]
    for i in range(n):
        f[i + sz] = FormalPowerSeries([-xs[i], 1])
    for i in range(sz - 1, 0, -1):
        f[i] = f[2 * i] * f[2 * i + 1]
    # Distinct objects (not [fps] * k) so no node can alias another.
    g = [FormalPowerSeries([0]) for _ in range(2 * sz)]
    # Evaluate f' at every x_i with a remainder tree.
    g[1] = f[1].differentiate() % f[1]
    for i in range(2, n + sz):
        g[i] = g[i >> 1] % f[i]
    for i in range(n):
        g[i + sz] = FormalPowerSeries(
            [ys[i] * pow(g[i + sz][0], MOD - 2, MOD) % MOD])
    # Combine bottom-up: sum_i w_i * prod_{j != i} (x - x_j).
    for i in range(sz - 1, 0, -1):
        g[i] = g[2 * i] * f[2 * i + 1] + g[2 * i + 1] * f[2 * i]
    return g[1][:n]


def berlekamp_massey(arr):
    """Return the connection polynomial c (c[0] == 1) found by
    Berlekamp-Massey for ``arr`` (list or FormalPowerSeries).

    Input values are reduced mod MOD first, so e.g. MOD counts as zero.
    """
    if isinstance(arr, FormalPowerSeries):
        arr = arr.arr
    n = len(arr)
    b = [1]
    c = [1]
    l, m, p = 0, 0, 1
    for i in range(n):
        m += 1
        d = arr[i] % MOD  # discrepancy
        for j in range(1, l + 1):
            d += c[j] * arr[i - j]
            d %= MOD
        if d == 0:
            continue
        t = c.copy()
        q = d * pow(p, MOD - 2, MOD) % MOD
        if len(c) < len(b) + m:
            c += [0] * (len(b) + m - len(c))
        for j in range(len(b)):
            c[j + m] -= q * b[j]
            c[j + m] %= MOD
        if 2 * l <= i:
            b = t
            l, m, p = i + 1 - l, 0, d
    return c


def linear_recurrence(arr, coeff, k):
    """Return [x^k] P(x) / Q(x) mod MOD using Bostan-Mori.

    arr:   numerator P, a list or FormalPowerSeries.
    coeff: denominator Q, a list or FormalPowerSeries; Q[0] must be nonzero.
    k:     non-negative coefficient index.

    >>> linear_recurrence([1], [1, -2, 1], 0)
    1
    """
    if isinstance(arr, FormalPowerSeries):
        arr = arr.arr
    if isinstance(coeff, FormalPowerSeries):
        coeff = coeff.arr
    assert k >= 0  # a negative k never reaches 0 under k >>= 1
    d = len(arr)
    p = FormalPowerSeries(arr)
    q = FormalPowerSeries(coeff)
    assert q[0] != 0
    while k:
        # Multiply through by Q(-x): the new denominator is even in x.
        r = [-q[i] if i & 1 else q[i] for i in range(len(q))]
        r = FormalPowerSeries(r + [0] * (d + 1 - len(q)))
        p *= r
        q *= r
        p = p[(k & 1)::2]
        q = q[::2]
        k >>= 1
    # [x^0] P / Q = P[0] / Q[0]; Q[0] is not necessarily 1.
    return p[0] * pow(q[0], MOD - 2, MOD) % MOD


class SegmentTree:
    """Static bottom-up tree over ``init_val``.

    Leaf i is ``tree[num + i]``; internal node i holds
    ``segfunc(tree[2 * i], tree[2 * i + 1])``, so ``tree[1]`` is the
    left-to-right fold of all leaves (padding leaves are ``ide_ele``).  Only
    construction is implemented.  All padding leaves reference the same
    ``ide_ele`` object, so ``segfunc`` must not mutate its arguments.
    """

    def __init__(self, init_val, segfunc, ide_ele):
        n = len(init_val)
        self.segfunc = segfunc
        self.ide_ele = ide_ele
        self.num = 1 << (n - 1).bit_length()
        self.tree = [ide_ele] * 2 * self.num
        self.size = n
        for i, val in enumerate(init_val):
            self.tree[self.num + i] = val
        for i in range(self.num - 1, 0, -1):
            self.tree[i] = segfunc(self.tree[2 * i], self.tree[2 * i + 1])


def merge(x, y):
    """Add fractions given as [numerator, denominator] pairs (no reduction)."""
    return [x[0] * y[1] + y[0] * x[1], x[1] * y[1]]


def solve(n, m):
    """Return [x^n] P / (Q - P) mod MOD, where P / Q = sum_{i=1..m} P_i / Q_i
    with P_i = x + ... + x^(i-1) and Q_i = 1 + x + ... + x^(i-1) - x^i."""
    init = []
    for i in range(1, m + 1):
        p = FormalPowerSeries([0] + [1] * (i - 1))
        q = FormalPowerSeries([1] + [1] * (i - 1) + [MOD - 1])
        init.append([p, q])
    seg = SegmentTree(
        init, merge, [FormalPowerSeries([0]), FormalPowerSeries([1])])
    p, q = seg.tree[1]
    q -= p
    return linear_recurrence(p, q.arr, n)


def main():
    """Read "N M" from standard input and print solve(N, M)."""
    n, m = map(int, input().split())
    print(solve(n, m))


if __name__ == "__main__":
    main()