# NTT-friendly prime: 998244353 = 119 * 2**23 + 1, so power-of-two transform
# lengths up to 2**23 divide MOD - 1.
MOD = 998244353
# Generator used to derive the roots of unity for each butterfly stage.
PRIMITIVE_ROOT = 3


def modinv(x: int) -> int:
    """Return the multiplicative inverse of ``x`` modulo ``MOD``.

    Computed as ``x ** (MOD - 2) % MOD`` (Fermat's little theorem, valid
    because ``MOD`` is prime). If ``x`` is a multiple of ``MOD`` no inverse
    exists and the result is ``0`` rather than an exception.
    """
    return pow(x, MOD - 2, MOD)


def ntt(a: list, invert: bool) -> None:
    """Apply an in-place number-theoretic transform to ``a`` modulo ``MOD``.

    Args:
        a: List of integers whose length is a power of two (a length of 0
            or 1 is accepted and leaves the values untouched apart from the
            inverse scaling). The list is modified in place.
        invert: When true, compute the inverse transform, including the
            final multiplication by ``1 / len(a)``.

    Raises:
        ValueError: If ``len(a)`` is not a power of two, or is too large
            for ``MOD`` to contain a root of unity of that order.
    """
    n = len(a)
    # Validate up front so the caller's list is never left half-transformed.
    if n & (n - 1):
        raise ValueError("ntt length must be a power of two, got %d" % n)
    if n > 1 and (MOD - 1) % n:
        raise ValueError("ntt length %d is too large for modulus %d" % (n, MOD))

    # Bit-reversal permutation: move each element to the index obtained by
    # reversing the bits of its position.
    j = 0
    for i in range(1, n):
        bit = n >> 1
        while j & bit:
            j ^= bit
            bit >>= 1
        j ^= bit
        if i < j:
            a[i], a[j] = a[j], a[i]

    # Iterative butterflies over blocks of size 2, 4, ..., n.
    length = 2
    while length <= n:
        half = length >> 1
        # Root of unity of order ``length`` (its inverse for the inverse NTT).
        wlen = pow(PRIMITIVE_ROOT, (MOD - 1) // length, MOD)
        if invert:
            wlen = modinv(wlen)
        for start in range(0, n, length):
            w = 1
            for k in range(start, start + half):
                u = a[k]
                v = a[k + half] * w % MOD
                a[k] = (u + v) % MOD
                # Python's % already yields a non-negative result.
                a[k + half] = (u - v) % MOD
                w = w * wlen % MOD
        length <<= 1

    if invert:
        # Scale by 1/n to complete the inverse transform.
        ninv = modinv(n)
        for i in range(n):
            a[i] = a[i] * ninv % MOD


def convolution(a, b):
    """Return the convolution of ``a`` and ``b`` with coefficients mod ``MOD``.

    Args:
        a: Sequence of integers (coefficients of the first polynomial).
        b: Sequence of integers (coefficients of the second polynomial).

    Returns:
        A new list of ``len(a) + len(b) - 1`` residues in ``[0, MOD)``, or an
        empty list when either input is empty. The inputs are not modified.
    """
    # An empty operand has no terms; without this guard the padding length
    # below goes negative and the transform receives a malformed buffer.
    if not a or not b:
        return []

    out_len = len(a) + len(b) - 1
    # Smallest power of two that can hold the full result.
    n = 1
    while n < out_len:
        n <<= 1

    fa = list(a) + [0] * (n - len(a))
    fb = list(b) + [0] * (n - len(b))
    ntt(fa, False)
    ntt(fb, False)
    # Pointwise product in the transformed domain.
    fa = [x * y % MOD for x, y in zip(fa, fb)]
    ntt(fa, True)
    return fa[:out_len]