"""Read ``n k`` from stdin and print sum_{i=0}^{n-1} C(n-1, i) * i**k mod 998244353."""

mod = 998244353


def com(n, mod):
    """Return factorial and inverse-factorial tables modulo a prime.

    Args:
        n: Largest index to tabulate. The tables always have at least two
            entries (indices 0 and 1), and ``n + 1`` entries when ``n >= 1``.
        mod: Prime modulus.

    Returns:
        ``(fact, factinv)`` where ``fact[i] == i! % mod`` and
        ``fact[i] * factinv[i] % mod == 1`` for every tabulated ``i``.

    Raises:
        ValueError: if ``n >= mod``. The inverse recurrence below is only
            valid for ``i < mod``; beyond that it would silently yield zeros.
    """
    if n >= mod:
        raise ValueError(f"n={n} must be smaller than the modulus {mod}")
    fact = [1, 1]
    factinv = [1, 1]
    inv = [0, 1]
    for i in range(2, n + 1):
        fact.append(fact[-1] * i % mod)
        # O(1) modular inverse for prime mod:
        # mod = (mod // i) * i + mod % i  =>  i^-1 = -(mod // i) * (mod % i)^-1
        inv.append(-inv[mod % i] * (mod // i) % mod)
        factinv.append(factinv[-1] * inv[-1] % mod)
    return fact, factinv


def solve(n, k, mod=mod):
    """Return sum_{i=0}^{n-1} C(n-1, i) * i**k modulo ``mod``.

    ``0**0`` counts as 1, because Python's ``pow(0, 0, mod)`` returns 1.
    Returns 0 when ``n <= 0`` (the sum is empty).
    """
    if n <= 0:
        return 0
    fact, factinv = com(n, mod)
    m = n - 1
    total = 0
    for i in range(n):
        # C(m, i) = m! / (i! * (m - i)!), reduced modulo mod.
        binom = fact[m] * factinv[i] % mod * factinv[m - i] % mod
        total = (total + binom * pow(i, k, mod)) % mod
    return total


def main():
    """Read ``n k`` from standard input and print ``solve(n, k)``."""
    n, k = map(int, input().split())
    print(solve(n, k))


if __name__ == "__main__":
    main()