import sys
MOD = 998244353


def solve(n, k, mod=MOD):
    """Return ``n * k * (k - 1) / k**n`` reduced modulo ``mod``.

    The division is done by multiplying with the modular inverse of
    ``k**n``, obtained via Fermat's little theorem (``a**(p-2)`` is the
    inverse of ``a`` modulo a prime ``p``), so ``mod`` is assumed prime.
    If ``k`` is a multiple of ``mod`` and ``n >= 1``, the inverse does not
    exist; as in the original one-liner, the result is then 0 rather than
    an error.
    """
    # Reduce each factor first so intermediate products stay small;
    # the final residue is the same as reducing the full product once.
    numerator = n % mod * (k % mod) % mod * ((k - 1) % mod) % mod
    inv_denominator = pow(pow(k, n, mod), mod - 2, mod)
    return numerator * inv_denominator % mod


def main():
    """Read ``N K`` from one line of stdin and print ``solve(N, K)``."""
    n, k = map(int, sys.stdin.readline().split())
    print(solve(n, k))


if __name__ == "__main__":
    main()