"""Print sum_{i=0}^{N-1} i**K * C(N-1, i) modulo 998244353.

Reads two integers ``N K`` from standard input.
"""

mod = 998244353
Fsize = 2 * 10 ** 5 + 100

# One list holds both tables:
#   fact[i]  == i!        for 0 <= i <= Fsize   (front half)
#   fact[-i] == (i!)^-1   for 1 <= i <= Fsize   (back half, via negative index)
# The length Fsize + Fsize + 5 keeps the two halves from overlapping, and
# fact[-0] == fact[0] == 1 doubles as (0!)^-1.
fact = [1] * (Fsize + Fsize + 5)
for i in range(2, Fsize + 1):
    fact[i] = fact[i - 1] * i % mod
# Fermat's little theorem (mod is prime): x^-1 == x^(mod-2).
fact[-Fsize] = pow(fact[Fsize], mod - 2, mod)
for i in reversed(range(2, Fsize + 1)):
    # (i-1)!^-1 == i!^-1 * i
    fact[-i + 1] = fact[-i] * i % mod


def comb(n, k):
    """Return the binomial coefficient C(n, k) modulo ``mod``.

    Returns 0 when k < 0 or k > n (this also covers negative n).

    Raises:
        ValueError: if n > Fsize. The precomputed table only covers
            0..Fsize. Indexing past it would otherwise read the
            inverse-factorial half and silently return a wrong value.
    """
    if k < 0 or k > n:
        return 0
    if n > Fsize:
        raise ValueError(f"comb: n={n} exceeds precomputed table size Fsize={Fsize}")
    return fact[n] * fact[-k] % mod * fact[-(n - k)] % mod


def solve(N, K):
    """Return sum_{i=0}^{N-1} i**K * C(N-1, i) modulo ``mod``.

    Note that pow(0, 0, mod) == 1, so the i = 0 term contributes 1 when K == 0.
    """
    ans = 0
    for i in range(N):
        ans = (ans + pow(i, K, mod) * comb(N - 1, i)) % mod
    return ans


def main():
    """Read ``N K`` from stdin and print ``solve(N, K)``."""
    N, K = map(int, input().split())
    print(solve(N, K))


if __name__ == "__main__":
    main()