# Binomial coefficients modulo a prime, using precomputed factorial tables.
mod = 998244353
fac = [1, 1]   # fac[i] = i! mod `mod`
finv = [1, 1]  # finv[i] = (i!)^-1 mod `mod`
inv = [0, 1]   # inv[i] = i^-1 mod `mod` (inv[0] is an unused placeholder)


def init(n):
    """Extend fac/finv/inv so that every index 0..n is available.

    Resumes from the current table length, so calling it repeatedly (or with
    a smaller n than already computed) is safe and does no redundant work.
    """
    for i in range(len(fac), n + 1):
        fac.append(fac[-1] * i % mod)
        # From mod = (mod // i) * i + mod % i:  i^-1 = -(mod // i) * (mod % i)^-1.
        # mod % i < i, so that inverse is already in the table.
        inv.append(-inv[mod % i] * (mod // i) % mod)
        finv.append(finv[-1] * inv[-1] % mod)


def com(n, k, mod=mod):
    """Return C(n, k) modulo `mod`, or 0 when n < 0, k < 0 or k > n.

    The tables are built for the module-level `mod`. Only that modulus gives
    meaningful results; the parameter is kept for call compatibility.
    Tables are extended on demand, so any n is accepted.
    """
    if n < 0 or k < 0 or n < k:
        return 0
    if n >= len(fac):
        init(n)
    return fac[n] * (finv[k] * finv[n - k] % mod) % mod


def solve(n, m):
    """Return sum_{i=m}^{n} C(n, i) mod `mod`.

    Computed as 2^n minus the first m terms C(n, 0..m-1).
    """
    ans = pow(2, n, mod)
    if n >= 0:
        init(n)  # size the tables to the actual input once, up front
    # Terms with i > n are zero, so stopping at n + 1 leaves the result unchanged.
    for i in range(min(m, n + 1)):
        ans = (ans - com(n, i, mod)) % mod
    return ans


def main():
    """Read N and M (one per line) from stdin and print the answer."""
    n = int(input())
    m = int(input())
    print(solve(n, m))


if __name__ == "__main__":
    main()