MOD = 998244353  # prime modulus, so Fermat inverses via pow(a, MOD - 2, MOD) are valid
mod = MOD  # original module-level name, kept for backward compatibility


def f(x, N, mod=MOD):
    """Return sum_i (c_i / (N - x + 1))**2 modulo ``mod`` for i = 1..N.

    Here c_i = min(i, x, N - x + 1, N - i + 1) is the number of start
    positions of a length-``x`` segment inside cells 1..N that cover cell i.
    There are N - x + 1 start positions in total. So c_i / (N - x + 1) is the
    probability that a uniformly placed segment covers cell i. The sum of its
    squares is therefore the expected overlap length of two independent,
    uniformly placed segments.

    Args:
        x: segment length; assumed 1 <= x <= N (not validated here).
        N: number of cells along this axis.
        mod: prime modulus used for the modular inverse (defaults to MOD).

    Returns:
        The expected overlap length as a residue in [0, mod).
    """
    # Modular inverse of the number of placements (Fermat's little theorem).
    inv = pow(N - x + 1, mod - 2, mod)
    total = 0
    for i in range(1, N + 1):
        # Reduce before squaring to avoid building ~120-bit integers each step.
        p = min(i, x, N - x + 1, N - i + 1) * inv % mod
        total = (total + p * p) % mod
    return total


def main():
    """Read ``H W A B`` from stdin and print 2*A*B - f(A, H) * f(B, W) mod MOD.

    Row and column placements enter as independent factors, so the expected
    intersection area factors as f(A, H) * f(B, W). The printed value is then
    the expected union area of two independently, uniformly placed A x B
    rectangles in an H x W grid. (Presumably that is the intended problem;
    confirm against the task statement.)
    """
    H, W, A, B = map(int, input().split())
    ans = 2 * A * B - f(A, H) * f(B, W)
    # Python's % always yields a non-negative result, even if ans < 0.
    print(ans % MOD)


if __name__ == "__main__":
    main()