def count_subsets_at_least(n, m, mod=998244353):
    """Return (2**n - sum_{i=0}^{m-1} C(n, i)) modulo ``mod``.

    Equivalently, the number of subsets of an n-element set that have at
    least ``m`` elements, reduced modulo ``mod``.

    Args:
        n: size of the ground set (non-negative). Assumes n < mod so that
            every divisor i + 1 used below is invertible modulo ``mod``.
        m: lower bound on subset size (non-negative).
        mod: prime modulus (default 998244353). It must be prime because
            the modular inverse is computed via Fermat's little theorem.

    Returns:
        The count modulo ``mod``, in the range [0, mod).
    """
    ans = pow(2, n, mod)
    binom = 1  # C(n, i) mod `mod`, starting from C(n, 0) = 1
    # C(n, i) == 0 for i > n, so terms beyond i = n contribute nothing.
    for i in range(min(m, n + 1)):
        ans = (ans - binom) % mod
        # C(n, i+1) = C(n, i) * (n - i) / (i + 1). `binom` is already reduced
        # mod p, so plain integer division would be inexact. Multiply by the
        # modular inverse of (i + 1) instead.
        binom = binom * (n - i) % mod * pow(i + 1, mod - 2, mod) % mod
    return ans


def main():
    """Read n and m (one per line) from stdin and print the result."""
    n = int(input())
    m = int(input())
    print(count_subsets_at_least(n, m))


if __name__ == "__main__":
    main()