MOD=998244353 n,k=map(int,input().split()) print(n*k*(k-1)*pow(pow(k,n,MOD),MOD-2,MOD)%MOD)