mod = 998244353
N, K = [int(x) for x in input().split()]
ans = N * K
ans *= K - 1
ans %= mod
ans *= pow(K, (mod - 2) * N, mod)
ans %= mod
print(ans)