n,k = map(int,input().split()) mod = 998244353 ans = n*k*(k-1)*pow(pow(k,n,mod),mod-2,mod)%mod print(ans)