MOD = 998244353


def binomial_tail_sum(n, m, mod=MOD):
    """Return ``sum(C(n, i) for i in range(m, n + 1))`` modulo ``mod``.

    The sum is computed as ``2**n`` minus the first ``m`` coefficients of
    row ``n`` (C(n, 0) .. C(n, m - 1)). This needs only O(m) work, so ``n``
    itself may be very large. When ``m > n`` the result is 0.

    Args:
        n: Row of Pascal's triangle.
        m: Smallest index ``i`` included in the sum.
        mod: Modulus. It must be prime and greater than ``m``, because the
            inverse table below relies on 1..m being invertible modulo a
            prime.

    Returns:
        The tail sum reduced into ``range(mod)``.
    """
    # Linear-time modular inverses: inv[i] = -(mod // i) * inv[mod % i].
    inv = [0, 1]
    for i in range(2, m + 1):
        inv.append(-inv[mod % i] * (mod // i) % mod)

    # Reduce n once so the per-step product stays small. The result is the
    # same because Python's % is a true mathematical modulo.
    n_mod = n % mod
    total = pow(2, n, mod)
    term = 1  # C(n, 0)
    for i in range(m):
        total = (total - term) % mod
        # C(n, i + 1) = C(n, i) * (n - i) / (i + 1). Once i reaches n
        # (when n < mod), the factor (n - i) is 0 and later terms vanish.
        term = term * (n_mod - i) % mod * inv[i + 1] % mod
    return total


def main():
    """Read n and m from stdin (one per line) and print the tail sum."""
    n = int(input())
    m = int(input())
    print(binomial_tail_sum(n, m))


if __name__ == "__main__":
    main()