"""Expected union area of two A x B rectangles placed at random in an H x W grid.

Reads ``H W A B`` from stdin and prints, modulo 998244353, the quantity
2*A*B - E[overlap]. Each rectangle's position is averaged uniformly over all
placements that fit (A along the H axis, B along the W axis). The overlap area
is the product of the per-axis overlaps, and the two axes vary independently,
so E[overlap] = calc(H, A) * calc(W, B).
"""
import sys

MOD = 998244353


def calc(H, A):
    """Return the expected overlap of two length-``A`` segments in ``[0, H]``, mod MOD.

    Both segment starts range independently and uniformly over the
    ``n = H - A + 1`` integer positions ``0..H-A``. Starts ``i`` and ``j``
    overlap by ``max(0, A - |i - j|)``. The result is the sum of that overlap
    over all ``n**2`` ordered pairs, divided by ``n**2`` via a modular inverse.

    Returns 0 when ``A > H`` (no segment fits), as the original loop did.
    """
    n = H - A + 1
    if n <= 0:
        return 0

    def run_sum(k):
        # sum_{d=0}^{k} (A - d) = (k + 1) * (2A - k) / 2. The product is always
        # even (one of k, k + 1 is even), so exact integer division is safe and
        # no modular inverse of 2 is needed.
        return (k + 1) * (2 * A - k) // 2

    total = 0
    for i in range(n):
        # Partners at distance d >= A contribute nothing, so cap each side at A.
        # Including d == A itself adds A - A = 0, which is harmless.
        left = min(A, i)
        right = min(A, n - 1 - i)
        # j == i (d == 0) is counted in both one-sided sums; remove one copy.
        total += run_sum(left) + run_sum(right) - A
    return total % MOD * pow(n * n % MOD, MOD - 2, MOD) % MOD


def solve(H, W, A, B):
    """Return (2*A*B - calc(H, A) * calc(W, B)) mod MOD."""
    overlap = calc(H, A) * calc(W, B)
    return (2 * A * B - overlap) % MOD


def main():
    """Read ``H W A B`` from one line of stdin and print the answer."""
    H, W, A, B = map(int, sys.stdin.readline().split())
    print(solve(H, W, A, B))


if __name__ == "__main__":
    main()