MOD = 998244353 N,K = map(int,input().split()) ans = pow(pow(K,N,MOD),MOD-2,MOD) ans *= K * (K-1) * N ans %= MOD print(ans)