#include<iostream>
using namespace std;
constexpr long mod=998244353;

long modpow(long a,long b,long c){
	long res=1;
	while(b>0){
		if(b&1)res=res*a%c;
		a=a*a%c;
		b>>=1;
	}
	return res;
}

int main(){
	long N,K;cin>>N>>K;
	long ans=1;
	for(int i=0;i<N;++i){
		ans=ans*modpow(K,mod-2,mod)%mod;
	}
	ans*=(K-1)*K*N%mod;
	ans%=mod;
	cout<<ans<<endl;
}