n, k = map(int, input().split())
mod = 998244353
# n*(k-1)/k^(n-1) mod p; negative exponent = modular inverse (Python 3.8+)
print(pow(k, -(n - 1), mod) * (k - 1) * n % mod)