"""Count subsets of an n-element set having at least max(m, 1) elements.

Reads n and m from standard input and prints
    2**n - 1 - sum(C(n, k) for k in 1..m-1)   (mod 998244353).
"""

import sys

MOD = 998244353


def count_subsets_at_least(n: int, m: int, mod: int = MOD) -> int:
    """Return sum(C(n, k) for k in range(max(m, 1), n + 1)) modulo *mod*.

    Computed as ``2**n - 1`` minus ``C(n, k)`` for ``1 <= k < m``.  Each
    ``C(n, k)`` is formed as the falling factorial ``n*(n-1)*...*(n-k+1)``
    times ``(k!)^-1``, so *n* may be far larger than *m*; only
    ``O(min(m, n))`` time and memory are used.

    Assumes ``n >= 0``, that *mod* is prime, and that ``min(m - 1, n) < mod``
    so every ``k!`` that is inverted is nonzero modulo *mod*.
    """
    total = (pow(2, n, mod) - 1) % mod
    # C(n, k) == 0 for k > n, so only k in 1..min(m-1, n) can contribute.
    limit = min(m - 1, n)
    if limit < 1:
        return total

    # inv_fact[k] == (k!)^-1 mod `mod` for 0 <= k <= limit, built downward
    # from a single Fermat inverse of limit!.
    fact_limit = 1
    for k in range(2, limit + 1):
        fact_limit = fact_limit * k % mod
    inv_fact = [1] * (limit + 1)
    inv_fact[limit] = pow(fact_limit, mod - 2, mod)
    for k in range(limit, 1, -1):
        inv_fact[k - 1] = inv_fact[k] * k % mod

    falling = 1  # n * (n-1) * ... * (n-k+1) reduced mod `mod`
    for k in range(1, limit + 1):
        falling = falling * (n - k + 1) % mod
        total = (total - falling * inv_fact[k]) % mod
    return total


def main() -> None:
    """Read n and m (first two whitespace-separated integers) and print the count."""
    n, m = map(int, sys.stdin.read().split()[:2])
    print(count_subsets_at_least(n, m))


if __name__ == "__main__":
    main()