MOD = 998244353


def solve(n, k, mod=MOD):
    """Return sum_{i=0}^{n-1} C(n-1, i) * i**k modulo ``mod``.

    ``pow(0, 0, mod)`` evaluates to 1, so when ``k == 0`` the i = 0 term
    contributes 1 (the result is then 2**(n-1) mod ``mod``).  Returns 0 when
    ``n <= 0``.

    Binomial coefficients come from factorial / inverse-factorial tables of
    length ``n``.  The inverses are built with the linear-time recurrence
    inv(i) = -(mod // i) * inv(mod % i), which assumes ``mod`` is prime and
    larger than ``n - 1`` (true for the default 998244353 with contest-sized
    ``n`` -- TODO confirm the input bounds).
    """
    if n <= 0:
        return 0
    fact = [1] * n
    inv = [1] * n
    finv = [1] * n
    for i in range(2, n):
        fact[i] = fact[i - 1] * i % mod
        # mod % i < i, so inv[mod % i] is already filled in.
        inv[i] = mod - inv[mod % i] * (mod // i) % mod
        finv[i] = finv[i - 1] * inv[i] % mod

    top = fact[n - 1]
    total = 0
    for i in range(n):
        binom = top * finv[i] % mod * finv[n - 1 - i] % mod
        total = (total + pow(i, k, mod) * binom) % mod
    return total


def main():
    """Read "N K" from stdin and print solve(N, K)."""
    n, k = map(int, input().split())
    print(solve(n, k))


if __name__ == "__main__":
    main()