N,K = map(int,input().split())

MOD = 998244353
a = K*(K-1) * N
b = pow(pow(K,N,MOD),MOD-2,MOD)

print(a*b % MOD)