import sys

MOD = 998244353  # prime modulus used for all arithmetic


def solve(n, k, mod=MOD):
    """Return ``n * k * (k - 1) / k**n`` reduced modulo ``mod``.

    The division by ``k**n`` is performed by multiplying with its modular
    inverse, obtained via Fermat's little theorem (``x**(mod-2) mod mod``).
    That requires ``mod`` to be prime and ``k % mod != 0``; if ``k`` is a
    multiple of ``mod`` the "inverse" evaluates to 0 and so does the result
    (same as the original one-shot script).

    NOTE(review): the meaning of the formula (presumably an expected value
    from the problem statement) is not visible here — confirm against the
    task before reusing.

    Args:
        n: Non-negative integer exponent / multiplier read from input.
        k: Integer read from input.
        mod: Prime modulus; defaults to 998244353.

    Returns:
        The value as an integer in ``range(mod)``.
    """
    # Reduce each factor first; equal to (n*k*(k-1)) % mod but keeps
    # intermediates small even for huge inputs.
    numerator = n % mod * (k % mod) % mod * ((k - 1) % mod) % mod
    inv_denominator = pow(pow(k, n, mod), mod - 2, mod)
    return numerator * inv_denominator % mod


def main():
    """Read ``N K`` from one line of stdin and print ``solve(N, K)``."""
    n, k = map(int, sys.stdin.readline().split())
    print(solve(n, k))


if __name__ == "__main__":
    main()