"""Print N*K*(K-1) / K**N modulo 998244353, reading "N K" from stdin."""

MOD = 998244353  # prime modulus, so inverses come from Fermat's little theorem


def solve(n: int, k: int, mod: int = MOD) -> int:
    """Return n * k * (k - 1) * (k**n)^(-1) modulo *mod*.

    The inverse of k**n is computed as (k**n)^(mod - 2) mod *mod*, which is
    valid because *mod* is prime.

    Args:
        n: Non-negative exponent / multiplier read from the input.
        k: Base value read from the input.
        mod: Prime modulus (defaults to 998244353).

    Returns:
        The modular value in ``range(mod)``. If ``k % mod == 0`` then k**n
        has no inverse and, as in the original script, the result is
        silently 0.
    """
    numerator = n % mod * (k % mod) % mod * ((k - 1) % mod) % mod
    # Reduce k**n modulo *mod* while exponentiating. Computing pow(k, n)
    # first would materialise an astronomically large integer for big n.
    denominator_inv = pow(pow(k, n, mod), mod - 2, mod)
    return numerator * denominator_inv % mod


def main() -> None:
    """Read "N K" from one line of stdin and print the answer."""
    n, k = map(int, input().split())
    print(solve(n, k))


if __name__ == "__main__":
    main()