"""Read ``n m k`` from stdin and print a counting formula modulo 998244353.

For ``k < m`` the printed value is

    (m - k + 1) * k**(2n) - (m - k) * (k - 1)**(2n)   (mod 998244353)

and otherwise it is ``m**(2n)``. For ``k >= 1`` this equals the number of
length-``2n`` sequences over ``1..m`` whose values all fit inside some window
of ``k`` consecutive integers (i.e. ``max - min <= k - 1``).
"""

MOD = 998244353


def solve(n: int, m: int, k: int, mod: int = MOD) -> int:
    """Return the formula value for ``n, m, k`` reduced modulo *mod*.

    Args:
        n: half the sequence length (the exponent used is ``2 * n``).
        m: size of the value range ``1..m``.
        k: window width; assumed ``>= 1`` for the counting interpretation.
        mod: modulus (defaults to 998244353).

    Returns:
        An int in ``[0, mod)``.
    """
    length = 2 * n
    if k >= m:
        # A window of width k already spans all of 1..m, so every one of the
        # m**length sequences qualifies.
        return pow(m, length, mod)
    # Inclusion-exclusion over the (m-k+1) width-k windows: adjacent windows
    # overlap in a width-(k-1) window, and there are (m-k) such overlaps.
    # Subtracting them makes each sequence count exactly once. Python's `%`
    # keeps the result non-negative.
    return (pow(k, length, mod) * (m - k + 1)
            - pow(k - 1, length, mod) * (m - k)) % mod


def main() -> None:
    """Read ``n m k`` from one line of stdin and print the answer."""
    n, m, k = map(int, input().split())
    print(solve(n, m, k))


if __name__ == "__main__":
    main()