def floor_sum(n, m, a, b): ans = 0 while True: if a >= m: ans += (n * (n - 1) >> 1) * (a // m) a %= m if b >= m: ans += n * (b // m) b %= m y = a * n + b if y < m: return ans n, b, m, a = y // m, y % m, a, m for _ in range(int(input())): n, m, l, r = map(int, input().split()) l += n - 2 ans = m + 1 ans += floor_sum(r + 1, n - 1, 1, 0) ans -= floor_sum(l + 1, n - 1, 1, 0) ans += floor_sum(l - m, n - 1, 1, 0) print(ans % 998244353)