"""Print sum_{i=0}^{M-1} C(M-1, i) * i**K modulo 998244353.

Reads two integers ``M K`` from one line of standard input.
"""

MOD = 998244353


def solve(M, K, mod=MOD):
    """Return the sum of C(M-1, i) * i**K over 0 <= i < M, reduced mod *mod*.

    Args:
        M: Upper bound (exclusive) of the summation index; the binomial row
            used is row ``M - 1``.  Values ``M <= 0`` give an empty sum (0).
        K: Exponent applied to the index.  ``pow(0, 0, mod)`` is 1, so for
            ``K == 0`` the ``i == 0`` term contributes ``C(M-1, 0) == 1``.
        mod: Modulus.  Must be prime and greater than ``M - 1``, because the
            inverse factorials are obtained via Fermat's little theorem.

    Returns:
        The sum as an int in ``range(mod)``.
    """
    if M <= 0:
        return 0
    n = M - 1

    # Factorials 0!..n! modulo mod.
    fact = [1] * (n + 1)
    for i in range(1, n + 1):
        fact[i] = fact[i - 1] * i % mod

    # One modular inverse for n!, then walk down using
    # (i-1)!^-1 = i!^-1 * i, avoiding a pow() call per index.
    inv_fact = [1] * (n + 1)
    inv_fact[n] = pow(fact[n], mod - 2, mod)
    for i in range(n, 0, -1):
        inv_fact[i - 1] = inv_fact[i] * i % mod

    total = 0
    for i in range(M):
        binom = fact[n] * inv_fact[i] % mod * inv_fact[n - i] % mod
        total = (total + pow(i, K, mod) * binom) % mod
    return total


def main():
    """Read ``M K`` from stdin and print ``solve(M, K)``."""
    M, K = map(int, input().split())
    print(solve(M, K))


if __name__ == "__main__":
    main()