# Prime modulus used for all modular arithmetic in this script.
mod = 998244353


def Map():
    """Read one line from stdin and return its whitespace-separated integers."""
    return list(map(int, input().split()))


def _coverage_sums(length, window):
    """Return (sum of c, sum of c**2) modulo ``mod`` over every index.

    For each index ``pos`` in ``range(length)``, ``c`` is the number of start
    offsets ``s`` with ``0 <= s <= length - window`` such that the window
    ``[s, s + window - 1]`` contains ``pos``.
    """
    total = 0
    total_sq = 0
    last_start = length - window
    for pos in range(length):
        # Valid starts covering pos form the range
        # [max(0, pos + 1 - window), min(pos, last_start)].
        count = min(pos, last_start) - max(0, pos + 1 - window) + 1
        total += count
        total_sq += pow(count, 2, mod)
    return total % mod, total_sq % mod


def solve(H, W, A, B):
    """Return (2*p*q*n - pp*qq) / n**2 modulo ``mod``.

    ``p``/``pp`` are the coverage sums along the H axis for window A,
    ``q``/``qq`` the same along the W axis for window B, and ``n`` is the
    number of placements ``(H - A + 1) * (W - B + 1)``.

    NOTE(review): this is presumably the expected number of cells covered by
    the union of two independently, uniformly placed A-by-B rectangles on an
    H-by-W grid (sum over cells of 1 - (1 - c/n)**2) -- confirm against the
    problem statement. Assumes 1 <= A <= H and 1 <= B <= W so that ``n`` is
    invertible modulo ``mod``; otherwise ``pow`` raises ValueError.
    """
    p, pp = _coverage_sums(H, A)
    q, qq = _coverage_sums(W, B)
    n = (H - A + 1) * (W - B + 1) % mod
    numerator = (2 * p * q * n - pp * qq) % mod
    # Division by n**2 is done via the modular inverse (Python 3.8+).
    return numerator * pow(pow(n, 2, mod), -1, mod) % mod


def main():
    """Read ``H W A B`` from stdin and print the answer."""
    H, W, A, B = Map()
    print(solve(H, W, A, B))


if __name__ == "__main__":
    main()