class Counting:
    """Precomputed factorial tables for binomial/permutation counts modulo a prime.

    Builds ``fac_arr[i] = i! mod mod`` and ``r_s_fac_arr[i] = (i!)^-1 mod mod``
    for ``0 <= i <= n``. The inverse of ``n!`` is obtained via Fermat's little
    theorem (``pow(x, mod - 2, mod)``), so ``mod`` must be prime and ``n`` must
    be smaller than ``mod``; otherwise ``n!`` is 0 modulo ``mod`` and the
    inverse table is all zeros.
    """

    def __init__(self, n, mod):
        self.mod = mod
        self.fac_arr = [1] * (n + 1)
        self.r_s_fac_arr = [1] * (n + 1)
        for i in range(1, n + 1):
            self.fac_arr[i] = self.fac_arr[i - 1] * i % mod
        # Invert n! once, then walk downwards: 1/i! = 1/(i+1)! * (i+1).
        self.r_s_fac_arr[n] = pow(self.fac_arr[n], mod - 2, mod)
        for i in range(n - 1, -1, -1):
            self.r_s_fac_arr[i] = self.r_s_fac_arr[i + 1] * (i + 1) % mod

    def nckMod(self, n, k):
        """Return C(n, k) modulo ``self.mod``; 0 when ``k < 0`` or ``k > n``.

        ``n`` must not exceed the table size passed to the constructor.
        """
        # Without this guard, n - k < 0 would silently index from the end
        # of the table and return a meaningless nonzero value.
        if k < 0 or k > n:
            return 0
        mod = self.mod
        return (
            self.fac_arr[n] * self.r_s_fac_arr[k] % mod
            * self.r_s_fac_arr[n - k] % mod
        )

    def npkMod(self, n, k):
        """Return P(n, k) = n! / (n-k)! modulo ``self.mod``; 0 when ``k < 0`` or ``k > n``.

        ``n`` must not exceed the table size passed to the constructor.
        """
        if k < 0 or k > n:
            return 0
        mod = self.mod
        return self.fac_arr[n] * self.r_s_fac_arr[n - k] % mod


def solve(n, k, mod=998244353):
    """Return sum_{i=1}^{n-1} i**k * C(n-1, i) modulo ``mod``.

    Returns 0 when ``n <= 1`` (the sum is empty).
    """
    ct = Counting(n, mod)
    ans = 0
    for i in range(1, n):
        ans = (ans + pow(i, k, mod) * ct.nckMod(n - 1, i)) % mod
    return ans


if __name__ == "__main__":
    n, k = map(int, input().split())
    print(solve(n, k))