"""Print N * (K - 1) * K / K**N modulo 998244353.

Input: a single line holding two integers N and K.
Output: the rational value N*(K-1)*K / K**N as a residue modulo the prime
998244353, i.e. the numerator times the modular inverse of K**N.

NOTE(review): the derivation of N*(K-1)*K as the numerator is not shown
here; the original notes only state "N*(K-1)*K = K**N * ans" -- confirm
against the problem statement.
"""
from collections import defaultdict, deque
from sys import stdin

readline = stdin.readline

mod = 998244353


def li():
    """Read one line from stdin and return its whitespace-separated ints."""
    return list(map(int, readline().split()))


def plus_mod(a, b):
    """Return (a + b) reduced modulo ``mod``."""
    return (a + b) % mod


def minus_mod(a, b):
    """Return (a - b) reduced modulo ``mod`` (always non-negative)."""
    return (a - b) % mod


def multiply_mod(a, b):
    """Return (a * b) reduced modulo ``mod``."""
    return (a * b) % mod


def pow_mod(i, p):
    """Return i**p reduced modulo ``mod`` for a non-negative exponent p.

    Delegates to the built-in three-argument ``pow``. That is the same
    square-and-multiply algorithm implemented in C, without recursion.
    """
    return pow(i, p, mod)


def inv_mod(i):
    """Return the modular inverse of i using Fermat's little theorem.

    This is valid because ``mod`` is prime. If i is a multiple of ``mod``
    there is no inverse, and the result is 0 (unchanged legacy behavior).
    """
    return pow_mod(i, mod - 2)


def solve(n, k):
    """Return n * (k - 1) * k / k**n as a residue modulo ``mod``.

    Derived from ans * k**n == n * (k - 1) * k, so
    ans == n * (k - 1) * k * inverse(k**n)  (mod ``mod``).
    """
    numerator = n * (k - 1) * k % mod
    return multiply_mod(numerator, inv_mod(pow_mod(k, n)))


def main():
    """Read N and K from stdin and print the answer."""
    N, K = li()
    print(solve(N, K))


# Guard so importing this module (e.g. for tests) does not consume stdin.
if __name__ == "__main__":
    main()