def modfac(n, MOD): f = 1 factorials = [1] for m in range(1, n + 1): f *= m f %= MOD factorials.append(f) inv = pow(f, MOD - 2, MOD) invs = [1] * (n + 1) invs[n] = inv for m in range(n, 1, -1): inv *= m inv %= MOD invs[m - 1] = inv return factorials, invs def modnCr(n,r): #上で求めたfacとinvsを引数に入れるべし(上の関数で与えたnが計算できる最大のnになる) return fac[n] * inv[n-r] * inv[r] % mod mod = 998244353 fac,inv = modfac(200000,mod) N,K = map(int,input().split()) ans = K * (K-1) * N al = pow(K,N,mod) ans *= pow(al,mod-2,mod) print ( ans % mod )