MOD = 998244353


def solve(n: int, k: int, mod: int = MOD) -> int:
    """Return ``n*k*(k-1) / k**n`` as a residue modulo ``mod``.

    The division is modular: ``k**n`` is inverted with Fermat's little
    theorem (``x**(mod-2)``), which is valid because ``mod`` is prime.
    If ``k`` is a multiple of ``mod`` the inverse does not exist and the
    result is 0 (the same value the original formula produced).

    Args:
        n: Exponent of the denominator and a factor of the numerator.
        k: Base of the denominator; the numerator uses ``k * (k - 1)``.
        mod: Prime modulus for the result.

    Returns:
        ``n * k * (k - 1) * inverse(k**n)`` reduced modulo ``mod``.
    """
    numerator = n * k * (k - 1) % mod
    # Three-argument pow keeps every intermediate below ``mod``. The
    # original ``k**n`` built an integer of about n*log2(k) bits first.
    denominator = pow(k, n, mod)
    inverse = pow(denominator, mod - 2, mod)
    return numerator * inverse % mod


def main() -> None:
    """Read ``n k`` from stdin and print ``solve(n, k)``."""
    n, k = map(int, input().split())
    print(solve(n, k))


if __name__ == "__main__":
    main()