"""Coefficient extraction from an iterated polynomial map, via NTT.

Reads two integers ``n`` and ``m`` from standard input, builds a polynomial
``f`` by applying ``f <- f * (f - 1) + 1`` (coefficients mod ``MOD``) ``n``
times starting from ``1 + x``, and prints
``f[2**n - m] * m! * (2**n - m)!  (mod MOD)``.
"""

# NTT-friendly prime: MOD - 1 = 2**23 * 7 * 17, so power-of-two transform
# lengths up to 2**23 are supported with PRIMITIVE_ROOT as generator.
MOD = 998244353
PRIMITIVE_ROOT = 3


def ntt(a, invert):
    """Transform the list ``a`` in place with a number-theoretic transform.

    Args:
        a: List of integers; its length is expected to be a power of two
            (``conv_ntt`` always pads to one). Modified in place.
        invert: When true, compute the inverse transform, including the
            final multiplication by the modular inverse of ``len(a)``.

    Returns:
        None. The result is written back into ``a``.
    """
    n = len(a)

    # Bit-reversal permutation so the butterflies below can run in place.
    rev = 0
    for i in range(1, n):
        bit = n >> 1
        while rev & bit:
            rev ^= bit
            bit >>= 1
        rev ^= bit
        if i < rev:
            a[i], a[rev] = a[rev], a[i]

    # Iterative butterflies over blocks of size 2, 4, ..., n.
    len_ = 2
    while len_ <= n:
        half = len_ // 2
        # Principal len_-th root of unity (or its inverse for the inverse
        # transform, obtained through Fermat's little theorem).
        wlen = pow(PRIMITIVE_ROOT, (MOD - 1) // len_, MOD)
        if invert:
            wlen = pow(wlen, MOD - 2, MOD)
        for start in range(0, n, len_):
            w = 1
            for k in range(start, start + half):
                u = a[k]
                v = a[k + half] * w % MOD
                a[k] = (u + v) % MOD
                # Python's % already yields a non-negative result here.
                a[k + half] = (u - v) % MOD
                w = w * wlen % MOD
        len_ <<= 1

    if invert:
        inv_n = pow(n, MOD - 2, MOD)
        for i in range(n):
            a[i] = a[i] * inv_n % MOD


def conv_ntt(a, b):
    """Return the product of polynomials ``a`` and ``b`` modulo ``MOD``.

    Args:
        a: Coefficient list of the first polynomial (lowest degree first).
        b: Coefficient list of the second polynomial.

    Returns:
        A new list of ``len(a) + len(b) - 1`` coefficients. Neither input
        list is modified.
    """
    total = len(a) + len(b)
    size = 1
    while size < total:
        size <<= 1

    # Zero-padded copies: the transforms below work in place.
    fa = a + [0] * (size - len(a))
    fb = b + [0] * (size - len(b))
    ntt(fa, False)
    ntt(fb, False)
    fa = [x * y % MOD for x, y in zip(fa, fb)]
    ntt(fa, True)
    return fa[:total - 1]


def _factorial_mod(k):
    """Return ``k!`` modulo ``MOD`` (1 when ``k <= 0``)."""
    result = 1
    for i in range(1, k + 1):
        result = result * i % MOD
    return result


def solve(n, m):
    """Compute ``f_n[2**n - m] * m! * (2**n - m)!`` modulo ``MOD``.

    ``f_0 = 1 + x`` and ``f_{k+1} = f_k * (f_k - 1)`` with the constant
    term then reset to 1, all coefficients taken modulo ``MOD``.

    Args:
        n: Number of iterations of the polynomial map.
        m: Selects the coefficient at index ``2**n - m``.

    Returns:
        The product modulo ``MOD``; 0 when ``2**n - m`` is not a valid
        exponent of ``f_n``.
    """
    size = 1 << n
    k = size - m
    # f_n has exactly size + 1 coefficients, so any exponent outside
    # [0, size] has coefficient zero. Guarding here also avoids Python's
    # negative-index wrap-around silently picking an unrelated coefficient.
    if k < 0 or k > size:
        return 0

    f = [1, 1]
    for _ in range(n):
        shifted = f.copy()
        shifted[0] = 0  # f - 1, because the constant term of f is always 1
        f = conv_ntt(f, shifted)
        f[0] = 1

    return f[k] * _factorial_mod(m) % MOD * _factorial_mod(k) % MOD


def main():
    """Read ``n m`` from standard input and print the answer."""
    n, m = map(int, input().split())
    print(solve(n, m))


if __name__ == "__main__":
    main()