N,K=map(int,input().split())
MOD=998244353
ans=N*K*(K-1)%MOD
ans*=pow(K,N*(MOD-2),MOD)
ans%=MOD
print(ans)