import sys

MOD = 998244353


# https://qiita.com/AkariLuminous/items/3e2c80baa6d5e6f3abe9#4-floor_sum
def floor_sum(n, m, a, b=0):
    """Return sum_{i=0}^{n-1} floor((a * i + b) / m).

    Runs the Euclid-like reduction loop, so the number of iterations grows
    logarithmically with the arguments rather than linearly with ``n``.
    ``a`` and ``b`` may be negative (Python's floor division/modulo are used
    to normalise them). ``n`` is expected to be >= 0; it is not validated.

    Args:
        n: number of terms (i runs over 0 .. n-1).
        m: denominator; must be >= 1.
        a: coefficient of i in the numerator.
        b: constant term of the numerator.

    Returns:
        The exact (arbitrary-precision) integer sum.

    Raises:
        ValueError: if ``m <= 0`` (the original code failed with an opaque
            ZeroDivisionError or nonsense deep inside the loop).
    """
    if m <= 0:
        raise ValueError(f"floor_sum requires m >= 1, got m={m}")
    ans = 0
    while True:
        # floor((a*i + b)/m) = (a//m)*i + floor(((a%m)*i + b)/m); summing the
        # first part over i < n gives (a//m) * n(n-1)/2. n(n-1) is always even,
        # so the final // 2 is exact even when a//m is negative.
        if a >= m or a < 0:
            ans += n * (n - 1) * (a // m) // 2
            a %= m
        # Likewise peel off the whole multiples of m contained in b.
        if b >= m or b < 0:
            ans += n * (b // m)
            b %= m
        # Now 0 <= a, b < m. If even the largest numerator is below m,
        # every remaining term is zero.
        y_max = a * n + b
        if y_max < m:
            break
        # Swap the roles of the axes and continue on a smaller instance.
        # The new m is the old a, which is >= 1 here (a == 0 would have
        # made y_max == b < m and broken out above).
        n, b, m, a = y_max // m, y_max % m, a, m
    return ans


def solve(n, m, left, right, mod=MOD):
    """Return the answer for one test case, reduced modulo ``mod``.

    Computes
        (M + 1) + sum_{x=0}^{M} [ floor((R - x)/(N - 1))
                                 - floor((L + N - 2 - x)/(N - 1)) ]
    with N=n, M=m, L=left, R=right. Algebraically, for L <= R + 1, each
    summand plus 1 equals the number of y in [L, R] with y == x (mod N - 1).

    ``n`` must be >= 2, since N - 1 is passed to ``floor_sum`` as the divisor.
    """
    d = n - 1
    count = m + 1
    total = (
        count
        + floor_sum(count, d, -1, right)
        - floor_sum(count, d, -1, left + n - 2)
    )
    return total % mod


def main():
    """Read T test cases of "N M L R" from stdin and print one answer per line."""
    readline = sys.stdin.readline
    t = int(readline())
    answers = []
    for _ in range(t):
        n, m, left, right = map(int, readline().split())
        answers.append(solve(n, m, left, right))
    # One write instead of T print() calls; empty output when T == 0.
    sys.stdout.write("".join(f"{v}\n" for v in answers))


if __name__ == "__main__":
    main()