n,k = map(int,input().split()) mod = 998244353 si = n*(k-1)*k%mod ans = si * pow(pow(k, n, mod), mod-2, mod) print(ans % mod)