MOD = 998244353


def count_subsets_at_least(n, m, mod=MOD):
    """Return sum(C(n, i) for i in range(max(m, 1), n + 1)) modulo ``mod``.

    This counts the non-empty subsets of an n-element set that have at least
    m elements. For m <= 1 it is 2**n - 1. For m > n it is 0.

    It evaluates the complement 2**n - 1 - sum_{i=1}^{k} C(n, i), with
    k = min(m - 1, n). That form stays cheap when n is huge and m is small.

    Args:
        n: size of the ground set (may exceed ``mod``).
        m: minimum subset size to count.
        mod: modulus. It must be prime and larger than k, because i! is
            inverted via Fermat's little theorem.

    Returns:
        The count reduced into ``range(mod)``.
    """
    k = min(m - 1, n)
    # Let r_j = (n - j + 1) / j, so that C(n, i) = r_1 * r_2 * ... * r_i. Then
    #   sum_{i=1}^{k} C(n, i) = r_1 * (1 + r_2 * (1 + ... * (1 + r_k)))
    # Evaluate this inside-out as one fraction num/den, so only a single
    # modular inverse is needed. The original did one inverse per term.
    num, den = 0, 1
    for j in range(k, 0, -1):
        num = (n - j + 1) * (num + den) % mod
        den = j * den % mod
    # den == k! mod p. It is invertible because p is prime and k < p.
    head = num * pow(den, mod - 2, mod) % mod
    return (pow(2, n, mod) - 1 - head) % mod


if __name__ == "__main__":
    n = int(input())
    m = int(input())
    print(count_subsets_at_least(n, m))