A=input().split() N=int(A[0]) K=int(A[1]) P=998244353 print((N*(K-1)*pow(K,P-N))%P)