n,k=map(int,input().split()) mod=998244353 ans=n*k*(k-1)*pow(pow(k,n,mod),mod-2,mod)%mod print(ans)