MOD = 998244353


def _binomial_power_sum(m_total, k):
    """Return sum_{m=1}^{m_total} C(m_total, m) * m**k modulo MOD.

    The m = 0 term is deliberately omitted; it equals 0**k, which is 0 for
    every k >= 1. For k == 0 the result is therefore 2**m_total - 1 rather
    than 2**m_total.
    NOTE(review): confirm the problem guarantees k >= 1, or that excluding
    the 0**0 term is intended.

    Returns 0 when m_total <= 0 (the sum is empty).
    """
    if m_total <= 0:
        # Empty sum. Also avoids indexing into zero-length factorial tables.
        return 0

    # Factorials 0! .. m_total! modulo MOD.
    fact = [1] * (m_total + 1)
    for i in range(1, m_total + 1):
        fact[i] = fact[i - 1] * i % MOD

    # Inverse factorials. The top one comes from Fermat's little theorem
    # (MOD is prime). The rest follow from 1/(i)! = (i+1) / (i+1)!.
    inv_fact = [1] * (m_total + 1)
    inv_fact[m_total] = pow(fact[m_total], MOD - 2, MOD)
    for i in range(m_total - 1, -1, -1):
        inv_fact[i] = inv_fact[i + 1] * (i + 1) % MOD

    top = fact[m_total]  # loop-invariant numerator of every C(m_total, m)
    total = 0
    for m in range(1, m_total + 1):
        binom = top * inv_fact[m] % MOD * inv_fact[m_total - m] % MOD
        total = (total + binom * pow(m, k, MOD)) % MOD
    return total


def main():
    """Read ``N K`` from stdin and print sum_{m=1}^{N-1} C(N-1, m) * m**K mod MOD."""
    import sys

    # Read everything at once without shadowing the ``input`` builtin.
    n, k = map(int, sys.stdin.read().split())
    print(_binomial_power_sum(n - 1, k))


if __name__ == "__main__":
    main()