MOD = 998244353


def solve(n, m, mod=MOD):
    """Count non-empty subsets of an n-element set with at least m elements, mod *mod*.

    Computes ``sum(C(n, k) for k in range(max(m, 1), n + 1)) % mod`` as
    ``2**n - 1`` (all non-empty subsets) minus C(n, 1) .. C(n, m - 1).
    Each C(n, i) is built incrementally as the falling factorial
    n * (n-1) * ... * (n-i+1) divided by i!, so only O(min(m, n)) terms are
    touched and n itself may be very large.

    Args:
        n: size of the ground set (non-negative).
        m: minimum subset size. Any m <= 1 counts every non-empty subset;
            m > n yields 0.
        mod: modulus. It must be prime, because i! is inverted via Fermat's
            little theorem. Also assumes min(m, n + 1) <= mod, so that i! is
            never divisible by mod — TODO confirm against the problem
            constraints.

    Returns:
        The count, reduced into ``range(mod)``.
    """
    ans = (pow(2, n, mod) - 1) % mod  # every non-empty subset
    falling = n % mod  # n * (n-1) * ... * (n-i+1): numerator of C(n, i)
    fact = 1  # i!: denominator of C(n, i)
    for i in range(1, min(m, n + 1)):
        # Remove subsets of size exactly i; pow(x, mod - 2, mod) == x^-1 for prime mod.
        ans = (ans - falling * pow(fact, mod - 2, mod)) % mod
        falling = falling * (n - i) % mod
        fact = fact * (i + 1) % mod
    return ans


def main():
    """Read n and m (one integer per line) from stdin and print solve(n, m)."""
    n = int(input())
    m = int(input())
    print(solve(n, m))


if __name__ == "__main__":
    main()