"""Evaluate k^(2n)*(m-k+1) - (k-1)^(2n)*(m-k) modulo 998244353.

Input: one line ``n m k``. Output: the value above, reduced modulo 998244353.
"""

MOD = 998244353


def solve(n: int, m: int, k: int, mod: int = MOD) -> int:
    """Return ``(k**(2n) * (m-k+1) - (k-1)**(2n) * (m-k)) % mod``.

    For 1 <= k <= m this equals the number of sequences of length 2n with
    entries in [1, m] whose max - min is at most k - 1.  Such a sequence
    lies inside a contiguous run of the m-k+1 width-k windows [t, t+k-1].
    The number of windows containing it exceeds the number of adjacent-window
    overlaps [t+1, t+k-1] containing it by exactly one, so "windows minus
    overlaps" counts it once.

    Args:
        n: half the sequence length (the exponent used is 2*n).
        m: largest allowed value.
        k: window width.
        mod: modulus for the result (defaults to 998244353).

    Returns:
        The expression above, reduced into the range [0, mod).
    """
    inside_windows = pow(k, 2 * n, mod) * (m - k + 1)
    inside_overlaps = pow(k - 1, 2 * n, mod) * (m - k)
    # Python's % with a positive modulus always lands in [0, mod), so a
    # negative difference needs no explicit "+ mod" correction.
    return (inside_windows - inside_overlaps) % mod


def main() -> None:
    """Read ``n m k`` from one line of stdin and print ``solve(n, m, k)``."""
    n, m, k = map(int, input().split())
    print(solve(n, m, k))


if __name__ == "__main__":
    main()