MOD = 998244353  # prime modulus used for all arithmetic


def tail_binomial_sum(n: int, m: int) -> int:
    """Return (2**n - sum(C(n, i) for i in range(m))) modulo MOD.

    For non-negative ``n`` this is the number of subsets of an ``n``-element
    set that contain at least ``m`` elements, reduced modulo ``MOD``.

    ``m`` must be smaller than ``MOD``: the inverse-table recurrence below
    only yields valid modular inverses for indices in ``1..MOD-1``.
    """
    # inv[i] is the modular inverse of i, built with the linear recurrence
    # inv[i] = -(MOD // i) * inv[MOD % i]  (mod MOD).  Index 0 is a placeholder.
    inv = [0, 1]
    for i in range(2, m + 1):
        inv.append(-inv[MOD % i] * (MOD // i) % MOD)

    binom = 1  # C(n, 0); holds C(n, i) mod MOD at the start of iteration i
    result = pow(2, n, MOD)
    for i in range(m):
        result = (result - binom) % MOD
        # C(n, i + 1) = C(n, i) * (n - i) / (i + 1); becomes 0 once i == n.
        binom = binom * (n - i) * inv[i + 1] % MOD
    return result


def main() -> None:
    """Read ``n`` and ``m`` (one integer per line) and print the answer."""
    n = int(input())
    m = int(input())
    print(tail_binomial_sum(n, m))


if __name__ == "__main__":
    main()