"""Compute sum_{l=0}^{N-1} C(N-1, l) * l**K modulo 998244353.

Reads two integers N and K from the first line of standard input and
prints the result.
"""
import sys

MOD = 998244353


def p(a, b, mod=MOD):
    """Return ``a**b % mod`` for ``b > 0`` and ``1`` for ``b <= 0``.

    Delegates to the built-in three-argument ``pow``. The ``b <= 0`` branch
    reproduces the behaviour of the original hand-rolled square-and-multiply
    loop, which never executed for non-positive exponents (so 0**0 == 1).
    """
    if b <= 0:
        return 1
    return pow(a, b, mod)


def solve(n, k, mod=MOD):
    """Return ``sum(C(n-1, i) * i**k for i in range(n)) % mod`` (with 0**0 == 1).

    Args:
        n: number of terms; ``n <= 0`` yields 0.
        k: exponent applied to each index (see ``p`` for ``k <= 0``).
        mod: modulus; assumed prime, since the fast path uses Fermat inverses.

    The binomial C(n-1, i) is advanced with the recurrence
    C(n-1, i+1) = C(n-1, i) * (n-1-i) / (i+1). When every divisor 1..n-1 is
    invertible modulo ``mod`` (i.e. n-1 < mod), the coefficient is kept reduced
    and division uses a modular inverse: O(n log mod) total. Otherwise it
    falls back to the exact integer recurrence, which is correct but slow
    because the coefficient grows to ~n bits.
    """
    use_inverse = n - 1 < mod
    total = 0
    binom = 1  # C(n-1, i) for the current i
    for i in range(n):
        total = (total + binom * p(i, k, mod)) % mod
        if use_inverse:
            # Division by (i+1) via Fermat's little theorem (mod is prime).
            binom = binom * (n - 1 - i) % mod * pow(i + 1, mod - 2, mod) % mod
        else:
            binom = binom * (n - 1 - i) // (i + 1)
    return total


def main():
    """Read ``N K`` from stdin and print ``solve(N, K)``."""
    n, k = map(int, sys.stdin.readline().split())
    print(solve(n, k))


if __name__ == "__main__":
    main()