from __future__ import annotations input class CombLUT: MOD = 998_244_353 _max = 0 F = [1] InvF = [1] _fibo = [0, 1] _lucas = [2, 1] @classmethod def expand(cls, n): if n <= cls._max: return n = n*4//3 cls.F.extend([1]*(n - cls._max)) for i in range(cls._max + 1, n+1): cls.F[i] = cls.F[i-1] * i % cls.MOD cls.InvF.extend([1]*(n - cls._max)) cls.InvF[n] = pow(cls.F[n], -1, cls.MOD) for i in range(n-1, cls._max, -1): cls.InvF[i] = cls.InvF[i+1] * (i+1) % cls.MOD cls._max = n @classmethod def comb(cls, n, r): if not (0 <= r <= n): return 0 if n > cls._max: cls.expand(n) return cls.F[n] * cls.InvF[r] % cls.MOD * cls.InvF[n-r] % cls.MOD @classmethod def hyper_comb(cls, n, r): if (r < 0): return 0 if (n < 0): return 0 if n == r == 0: return 1 if not (0 <= r <= n-1+r): return 0 return cls.comb(n-1+r, r) @classmethod def neg_hyper_comb(cls, n, r): if (r < 0): return 0 if n >= 0: return cls.hyper_comb(n, r) sgn = -1 if r&1 else 1 return sgn * cls.comb(-n, r) @classmethod def factorial(cls, n): assert 0 <= n if n > cls._max: cls.expand(n) return cls.F[n] @classmethod def inv_factorial(cls, n): assert 0 <= n if n > cls._max: cls.expand(n) return cls.InvF[n] def subset_sum(A, *, func=None, initial=None): if (func is None): if initial is None: initial = 0 N = len(A) dp = [initial] * (1< None: """ Sets a new modulus and recalculates the associated parameters. Args: mod (int): The new prime modulus. """ cls._mod = mod cls._root = PrimeFactor.primitive_root(mod) cls._rank2 = tzcount(mod-1) root = [0]*(cls._rank2 + 1) iroot = [0]*(cls._rank2 + 1) root[cls._rank2] = pow(cls._root, (mod-1)>>cls._rank2, mod) iroot[cls._rank2] = pow(root[cls._rank2], mod-2, mod) for i in range(cls._rank2 - 1, -1, -1): root[i] = root[i+1] * root[i+1] % mod iroot[i] = iroot[i+1] * iroot[i+1] % mod cls._imag = root[2] cls._iimag = iroot[2] cls._rate2, cls._irate2 = cls.__calculate_rates(root, iroot, 2) cls._rate3, cls._irate3 = cls.__calculate_rates(root, iroot, 3) @classmethod def __calculate_rates(cls, root: list, iroot: list, ofs: int) -> tuple: """ Calculates the rates used in the butterfly transformations. Args: root (list): List of roots. iroot (list): List of inverse roots. ofs (int): Offset to start calculation. Returns: tuple: A tuple containing two lists: rates and inverse rates. """ rate = [0]*max(0, cls._rank2 - (ofs-1)) irate = [0]*max(0, cls._rank2 - (ofs-1)) prod, iprod = 1, 1 for i in range(cls._rank2 - (ofs-1)): rate[i] = root[i + ofs] * prod % cls._mod irate[i] = iroot[i + ofs] * iprod % cls._mod prod *= iroot[i + ofs] prod %= cls._mod iprod *= root[i + ofs] iprod %= cls._mod return rate, irate @classmethod def butterfly(cls, a: list[int]) -> None: """ Applies the butterfly transformation on the input list. Args: a (list[int]): The input list. """ n = len(a) h = (n-1).bit_length() for len_ in range(0, h-1, 2): p = 1<<(h - len_ - 2) rot = 1 for s in range(1< None: """ Applies the inverse butterfly transformation on the input list. Args: a (list[int]): The input list. """ n = len(a) h = (n-1).bit_length() for len_ in range(h, 1, -2): p = 1<<(h - len_) irot = 1 for s in range(1<<(len_-2)): irot2 = irot*irot%cls._mod irot3 = irot2*irot%cls._mod offset = s<<(h - len_ + 2) for i in range(p): a0 = a[i + offset] a1 = a[i + offset + p] a2 = a[i + offset + p*2] a3 = a[i + offset + p*3] a2na3iimag = (a2-a3) * cls._iimag % cls._mod a[i + offset] = (a0+a1+a2+a3) % cls._mod a[i + offset + p] = (a0-a1 + a2na3iimag) * irot % cls._mod a[i + offset + p*2] = (a0+a1-a2-a3) * irot2 % cls._mod a[i + offset + p*3] = (a0-a1 - a2na3iimag) * irot3 % cls._mod if s+1 != (1<<(len_-2)): irot *= cls._irate3[(~s& -~s).bit_length() - 1] irot %= cls._mod if h&1: p = 1<<(h-1) for i in range(p): l = a[i] r = a[i+p] a[i] = (l+r) % cls._mod a[i+p] = (l-r) % cls._mod @classmethod def convolution(cls, a: list[int], b: list[int]) -> list[int]: """ Computes the convolution of two lists. Args: a (list[int]): First input list. b (list[int]): Second input list. Returns: list[int]: The convolution of a and b. Raises: ValueError: If the length of the result exceeds the supported length. """ n, m = len(a), len(b) if n+m-1 > (1<>1)&0x5555_5555_5555_5555) x = (x&0x3333_3333_3333_3333) + ((x>>2)&0x3333_3333_3333_3333) x = (x + (x>>4))&0x0f0f_0f0f_0f0f_0f0f x += (x>>8) x += (x>>16) x += (x>>32) return x&0x0000_0000_0000_007f def solve(N, K, A): def resolve(k, a): lamda = Fraction(1, a) coef = [Fraction(0)] * k pows = powers(lamda, k+10, MOD=998_244_353) for n in range(k): # coef[n] = pow(lamda, n, MOD) * CombLUT.inv_factorial(n) coef[n] = pows[n] * CombLUT.inv_factorial(n) % MOD return ExpPoly(lamda, coef) F = [resolve(k, a) for k, a in zip(K, A)] dp = subset_sum(F, func=lambda x, y: x*y, initial=1) I = [f.integral() if isinstance(f, ExpPoly) else 0 for f in dp] popcnt = [Fraction(0)] * (N+1) for S in range(1<