"""Read ``n k`` from one stdin line and print n*k*(k-1) / k**n modulo 998244353."""

M = 998244353


def solve(n: int, k: int, mod: int = M) -> int:
    """Return ``n * k * (k - 1) * k**(-n)`` reduced modulo *mod*.

    The division by ``k**n`` is done with a modular inverse obtained via
    Fermat's little theorem, ``k**(mod - 2) ≡ k**-1 (mod p)``. That identity
    only holds when *mod* is prime; the default 998244353 is prime.

    Args:
        n: Exponent of the denominator and a linear factor of the numerator.
        k: Base of the denominator; the numerator also uses ``k * (k - 1)``.
        mod: Prime modulus for the result (defaults to 998244353).

    Returns:
        The value in ``range(mod)``.

    Note:
        If *k* is a multiple of *mod*, no inverse exists. ``pow`` then yields
        0, and so does the result for ``n >= 1``. This matches the original
        one-line script's behavior.
    """
    inv_k = pow(k, mod - 2, mod)  # k^-1 mod p via Fermat (p prime)
    return n * k * (k - 1) * pow(inv_k, n, mod) % mod


def main() -> None:
    """Read ``n k`` from a single stdin line and print ``solve(n, k)``."""
    n, k = map(int, input().split())
    print(solve(n, k))


if __name__ == "__main__":
    main()