def combs_mod(n, k, mod):
    """Return binomial coefficients C(i, k) % mod for i from k up to n.

    The result is a list ``c`` of length ``n + 1`` with ``c[i] == C(i, k) % mod``
    for every ``k <= i <= n``. Entries with ``i < k`` are filler (always 1) and
    are NOT meaningful binomial coefficients.

    Inverses come from Fermat's little theorem (``pow(j, mod - 2, mod)``), so
    ``mod`` is assumed to be prime and larger than ``n - k``.
    """
    # inv[j] = modular inverse of j for j in 1..n-k (inv[0] is unused).
    inv = [1] * (n - k + 1)
    for j in range(1, n - k + 1):
        inv[j] = pow(j, mod - 2, mod)

    c = [1] * (n + 1)  # c[k] = C(k, k) = 1 seeds the recurrence below.
    for i in range(k + 1, n + 1):
        # C(i, k) = C(i - 1, k) * i / (i - k)
        c[i] = c[i - 1] * i * inv[i - k] % mod
    return c


def count_mod(N, M, A, B, mod=998244353):
    """Return the answer for parameters N, M, A, B, reduced modulo ``mod``.

    With ``p = B - (N - 1) * A`` this returns 0 when ``p < 0``; otherwise

        N! * ( C(p+N-1, N-1) * (M - B) + sum_{i=0}^{p-1} C(i+N-1, N-1) )  mod ``mod``.

    NOTE(review): this appears to count ordered N-tuples whose sorted
    consecutive differences are all >= A and whose span (max - min) is at most
    B, with M - d placements for a span d. That is an inference from the
    formula; confirm it against the original problem statement. The
    ``(M - B)`` factor can be negative, so results are presumably only
    meaningful under the problem's constraints (e.g. B < M). TODO confirm.
    """
    p = B - (N - 1) * A
    if p < 0:
        return 0
    # binom[j] = C(j, N-1) for N-1 <= j <= p+N-1.
    binom = combs_mod(p + N - 1, N - 1, mod)
    total = binom[p + N - 1] * (M - B) % mod
    # Adds C(i+N-1, N-1) for i = 0..p-1; empty slice when p == 0.
    total = (total + sum(binom[N - 1:p + N - 1])) % mod
    # Multiply by N! (mod ``mod``).
    for i in range(2, N + 1):
        total = total * i % mod
    return total


def solve(line=None):
    """Parse "N M A B" and return the answer modulo 998244353.

    line: the whitespace-separated input line. When omitted (the original
    behavior), a single line is read from stdin with ``input()``.
    """
    if line is None:
        line = input()
    N, M, A, B = map(int, line.split())
    return count_mod(N, M, A, B, 998244353)


if __name__ == "__main__":
    # Only read stdin when run as a script, not when imported.
    print(solve())