"""Read ``N M A B`` from stdin and print a gap-constrained count mod 998244353.

The printed value is

    N! * sum_{x=0}^{x_max} C(x + N - 2, x) * (M - A*(N-1) - x)

with ``x_max = min(B, M - 1) - A*(N-1)``.

NOTE(review): the formula matches counting ordered N-tuples of integers in
1..M whose pairwise differences are all >= A and whose max - min is <= B
(x is the slack spread over the N-1 sorted gaps, the last factor is the
number of placements of the minimum, N! orders the tuple). The problem
statement is not part of this file -- confirm against it.
"""

MOD = 998244353


def _factorial_tables(limit):
    """Return ``(fact, inv_fact)`` lists modulo MOD for indices 0..limit."""
    fact = [1] * (limit + 1)
    for i in range(1, limit + 1):
        fact[i] = fact[i - 1] * i % MOD

    # One modular inverse via Fermat's little theorem (MOD is prime), then
    # walk downwards: 1/(i-1)! == (1/i!) * i.
    inv_fact = [1] * (limit + 1)
    inv_fact[limit] = pow(fact[limit], MOD - 2, MOD)
    for i in range(limit, 0, -1):
        inv_fact[i - 1] = inv_fact[i] * i % MOD
    return fact, inv_fact


def solve(N, M, A, B):
    """Return the count described in the module docstring, reduced mod MOD.

    Args:
        N: number of elements in the tuple.
        M: size of the value range.
        A: minimum allowed difference between any two elements.
        B: maximum allowed difference between the largest and smallest.

    Returns:
        The count modulo 998244353; 0 when no slack value is feasible.
    """
    if N < 1:
        # The summation is empty/zero for N == 0; keep that result.
        return 0
    if N == 1:
        # A single element has no gaps and a spread of 0, so the general
        # formula's C(x - 1, x) term does not apply.
        return M % MOD if M > 0 and B >= 0 else 0

    gaps = N - 1
    min_span = A * gaps
    # The spread can never exceed M - 1; without this cap the placement
    # factor below would turn negative and subtract from the total.
    max_slack = min(B, M - 1) - min_span
    if max_slack < 0:
        return 0

    # Size the tables from the input: the largest factorial index used is
    # either N (for N!) or max_slack + N - 2 (for the binomials).
    fact, inv_fact = _factorial_tables(max(N, max_slack + gaps - 1))

    inv_fixed = inv_fact[gaps - 1]
    total = 0
    for slack in range(max_slack + 1):
        # C(slack + N - 2, slack): ways to split the slack over N - 1 gaps.
        ways = fact[slack + gaps - 1] * inv_fact[slack] % MOD * inv_fixed % MOD
        placements = M - min_span - slack
        total = (total + ways * placements) % MOD
    return total * fact[N] % MOD


def main():
    """Parse one line ``N M A B`` from stdin and print the answer."""
    N, M, A, B = map(int, input().split())
    print(solve(N, M, A, B))


if __name__ == "__main__":
    main()