"""Read N and K from stdin and print K*(K-1)*N * inverse(K)**N modulo 998244353."""

MOD = 998244353


def xgcd(a, b):
    """Extended Euclidean algorithm.

    Returns:
        A tuple (g, x, y) such that a*x + b*y == g, where g is the value the
        Euclidean iteration terminates on (gcd(a, b) for non-negative inputs).
    """
    x0, y0, x1, y1 = 1, 0, 0, 1
    while b != 0:
        q, a, b = a // b, b, a % b
        # Maintain the Bezout coefficients alongside the remainders.
        x0, x1 = x1, x0 - q * x1
        y0, y1 = y1, y0 - q * y1
    return a, x0, y0


def modinv(a, m):
    """Return the multiplicative inverse of ``a`` modulo ``m``, in [0, m).

    Raises:
        ValueError: if ``a`` and ``m`` are not coprime, so no inverse exists.
            (ValueError subclasses Exception, so existing ``except Exception``
            handlers keep working.)
    """
    g, x, _ = xgcd(a, m)
    if g != 1:
        raise ValueError('modular inverse does not exist')
    return x % m


def solve(n, k, mod=MOD):
    """Return (k * (k - 1) * n) * modinv(k, mod)**n reduced modulo ``mod``.

    Equivalent to starting from k*(k-1)*n % mod and multiplying by the inverse
    of k, n times. Modular exponentiation makes this O(log n) instead of O(n),
    so very large n is handled quickly.

    Raises:
        ValueError: if k has no inverse modulo ``mod`` (k % mod == 0 for prime mod).
    """
    numerator = k * (k - 1) * n % mod
    inv_k = modinv(k, mod)
    # Three-argument pow performs fast modular exponentiation.
    return numerator * pow(inv_k, n, mod) % mod


if __name__ == "__main__":
    N, K = map(int, input().split())
    print(solve(N, K))