MOD = 998244353


def _comb_mod(n, k):
    """Return C(n, k) modulo MOD, or 0 when n < 0, k < 0 or n < k.

    Computed as the falling factorial n*(n-1)*...*(n-k+1) times the modular
    inverse of k!, so the cost is O(k) no matter how large n is, and the
    result stays correct even when n >= MOD. Requires k < MOD so that k! is
    invertible modulo the prime MOD.
    """
    if n < 0 or k < 0 or n < k:
        return 0
    numerator = 1
    denominator = 1
    for i in range(k):
        numerator = numerator * ((n - i) % MOD) % MOD
        denominator = denominator * (i + 1) % MOD
    # Fermat's little theorem: x^(MOD-2) is the inverse of x modulo prime MOD.
    return numerator * pow(denominator, MOD - 2, MOD) % MOD


def solve(N, M, A, B):
    """Return the answer for parameters N, M, A, B, modulo MOD.

    With low = A*(N-1), high = min(B, M-1), H = N-1 and a = M - low, this
    evaluates the sum over d in [low, high] of
        (M - d) * C(d - low + H - 1, H - 1)
    in closed form, via the hockey-stick identity:
        a * C(T_max + H, H) - H * C(T_max + H, H + 1),  T_max = high - low,
    and multiplies it by N!.

    NOTE(review): this presumably counts orderings of N labelled items on
    integer positions 0..M-1 with consecutive gaps >= A and total span <= B.
    That reading is inferred from the formula; confirm it against the
    problem statement.

    Runs in O(N) time and O(1) extra memory, independent of M and B.
    """
    low = A * (N - 1)
    high = min(B, M - 1)
    if low > high:
        # The minimum required span already exceeds what is allowed.
        return 0

    H = N - 1
    T_max = high - low
    a = M - low
    n = T_max + H

    # Python's % with a positive modulus is always non-negative, so no
    # separate "< 0" correction is needed.
    total = (a * _comb_mod(n, H) - H * _comb_mod(n, H + 1)) % MOD

    n_factorial = 1
    for i in range(2, N + 1):
        n_factorial = n_factorial * i % MOD
    return total * n_factorial % MOD


def main():
    """Read N, M, A, B (whitespace-separated) from stdin and print the answer."""
    import sys

    data = sys.stdin.read().split()
    N = int(data[0])
    M = int(data[1])
    A = int(data[2])
    B = int(data[3])
    print(solve(N, M, A, B))


if __name__ == '__main__':
    main()