"""Read two integers N and K and print N*K*(K-1) / K**N modulo 998244353."""

# Prime modulus used for all modular arithmetic in this script.
mod = 998244353


def modinv(a):
    """Return the modular inverse of *a* modulo ``mod``.

    Uses Fermat's little theorem (a**(mod-2) == a**-1 when ``mod`` is prime),
    so the result is only a true inverse when *a* is not a multiple of
    ``mod``; for such multiples the result is 0.
    """
    return pow(a, mod - 2, mod)


def solve(N, K):
    """Return N*K*(K-1) * (K**N)**-1 reduced modulo ``mod``.

    NOTE(review): the meaning of this closed form comes from the problem
    statement, which is not visible here — confirm against it.
    """
    # K**N is reduced modulo ``mod`` before inverting so the inverse is
    # computed on a small residue rather than on a huge integer.
    denominator_inverse = modinv(pow(K, N, mod))
    return N * K * (K - 1) * denominator_inverse % mod


def main():
    """Read ``N K`` from one line of standard input and print the answer."""
    N, K = map(int, input().split())
    print(solve(N, K))


# Guarded so that importing this module does not block on stdin.
if __name__ == "__main__":
    main()