MOD = 998244353


def solve(n: int, k: int, mod: int = MOD) -> int:
    """Return ``n * (k - 1) * k**(mod - n)`` reduced modulo *mod*.

    When *mod* is prime and *k* is not a multiple of it, Fermat's little
    theorem gives ``k**(mod - 1) == 1 (mod mod)``, so ``k**(mod - n)`` equals
    ``k**(1 - n)``. The result is therefore the modular residue of the
    fraction ``n * (k - 1) / k**(n - 1)``.

    Args:
        n: First input value (``N``).
        k: Second input value (``K``).
        mod: Modulus; defaults to the prime 998244353.

    Returns:
        An integer in ``range(mod)``.

    Raises:
        ValueError: only if ``n > mod`` (negative exponent) and *k* is not
            invertible modulo *mod*; see ``pow`` for Python 3.8+.
    """
    # mod - n is used instead of an explicit inverse of k**(n - 1): both are
    # equivalent under Fermat when mod is prime, and this keeps the exponent
    # non-negative for n <= mod.
    return n * (k - 1) * pow(k, mod - n, mod) % mod


def main() -> None:
    """Read ``N K`` from one line of stdin and print the answer."""
    n, k = map(int, input().split())
    print(solve(n, k))


if __name__ == "__main__":
    main()