mod = 998244353


def solve(n: int, k: int, modulus: int = mod) -> int:
    """Return ``n * (k - 1) / k**(n - 1)`` reduced modulo ``modulus``.

    The division is done with a modular inverse: three-argument ``pow`` with a
    negative exponent (Python 3.8+) returns ``k**-(n-1) mod modulus``.

    Raises:
        ValueError: if ``1 - n`` is negative and ``k`` has no inverse modulo
            ``modulus`` (e.g. ``k`` is a multiple of a prime ``modulus``).
    """
    # pow(k, 1 - n, modulus) is the modular inverse of k**(n-1) when n > 1,
    # and plain 1 when n == 1 (exponent 0).
    return n * (k - 1) * pow(k, 1 - n, modulus) % modulus


def main() -> None:
    """Read ``n k`` from one line of stdin and print ``solve(n, k)``."""
    n, k = map(int, input().split())
    print(solve(n, k))


# Guard so importing this module (e.g. for testing) does not block on stdin.
if __name__ == "__main__":
    main()