"""Count, modulo a prime, the expected coverage of two random rectangles.

An A x B rectangle is placed uniformly at random, aligned to the grid, on an
H x W grid. For each row i, p[i] is the probability that the rectangle covers
that row; q[j] is the same for each column j. The printed value is
2*sum(p)*sum(q) - sum(p^2)*sum(q^2) (mod 998244353). By expansion, this equals
sum over cells (i, j) of 1 - (1 - p[i]*q[j])**2. That is the expected number of
cells covered by at least one of two independent placements.
"""

MOD = 998244353  # prime, so pow(x, MOD - 2, MOD) is the inverse of x (Fermat)


def _coverage_probs(n, k, mod=MOD):
    """Return, per position on an axis of length n, P(covered) mod `mod`.

    A segment of length k starts at one of the n - k + 1 offsets, chosen
    uniformly. Entry i is (#offsets whose segment contains i) / (n - k + 1),
    expressed as a residue mod `mod`.
    """
    inv_placements = pow(n - k + 1, mod - 2, mod)
    probs = []
    for i in range(n):
        # Start offsets t with t <= i <= t + k - 1, clipped to [0, n - k].
        count = min(n - k, i) - max(0, i - k + 1) + 1
        probs.append(count * inv_placements % mod)
    return probs


def solve(h, w, a, b, mod=MOD):
    """Return 2*Sp*Sq - Sp2*Sq2 (mod `mod`) for an a x b rectangle on an h x w grid.

    Sp and Sp2 are the sums of the row-coverage probabilities and of their
    squares. Sq and Sq2 are the same sums over columns.
    """
    p = _coverage_probs(h, a, mod)
    q = _coverage_probs(w, b, mod)
    sp = sum(p) % mod
    sp2 = sum(v * v for v in p) % mod
    sq = sum(q) % mod
    sq2 = sum(v * v for v in q) % mod
    return (2 * sp * sq - sp2 * sq2) % mod


def main():
    """Read "H W A B" from stdin and print the answer."""
    H, W, A, B = map(int, input().split())
    print(solve(H, W, A, B))


if __name__ == "__main__":
    main()