# Read two integers N and K from one line of stdin, then print
#     (K * (K - 1) * N) / K**N   modulo the prime 998244353.
# The division is done by multiplying with the modular inverse of K**N,
# computed with Fermat's little theorem: for a prime MOD and an x not
# divisible by MOD, x**(MOD - 2) is the inverse of x modulo MOD.
# NOTE(review): if K is a multiple of MOD, K**N has no inverse and the Fermat
# power silently yields 0. Confirm that the input constraints rule this out.

MOD = 998244353

n, k = [int(tok) for tok in input().split()]

numerator = n * k * (k - 1) % MOD
denominator = pow(k, n, MOD)
denominator_inv = pow(denominator, MOD - 2, MOD)  # Fermat inverse (MOD is prime)

print(numerator * denominator_inv % MOD)