"""Answer T independent queries (N, M, L, R) modulo 998244353.

Each query's answer is computed as

    sum_{i=L}^{R} (1 + i // (N - 1) + (M - i) // (N - 1))   (mod MOD)

in O(1) per query via the closed-form floor-sum helper ``calc``.
"""
import sys

MOD = 998244353


def calc(n, X, Y):
    """Return ``sum(i // n for i in range(X, Y + 1)) % MOD`` in O(1).

    The range [X, Y] is split at ``x`` (the first multiple of ``n`` that is
    >= X) and ``y`` (the last multiple of ``n`` that is <= Y):

    * ``X .. x-1`` all share the quotient ``X // n``;
    * ``y .. Y`` all share the quotient ``Y // n``;
    * ``x .. y-1`` is made of whole runs of ``n`` equal quotients
      ``x//n, ..., (y-1)//n``, whose total is an arithmetic series.

    Python's floor division keeps this correct for negative bounds as well.
    An empty range (``X == Y + 1``) yields 0.

    Raises:
        ZeroDivisionError: if ``n == 0``.
    """
    x = -(-X // n) * n  # ceil(X / n) * n
    y = Y // n * n      # floor(Y / n) * n
    if x > y:
        # No multiple of n lies in [X, Y], so every term equals Y // n.
        return (Y - X + 1) * (Y // n) % MOD
    total = (x - X) * (X // n)
    total += (Y - y + 1) * (Y // n)
    # n * (x//n + ... + (y-1)//n): here (y - x) == n * count, and
    # count * (first + last) is always even, so the // 2 is exact.
    total += (y - x) * (x // n + (y - 1) // n) // 2
    return total % MOD


def solve(n, m, l, r):
    """Return ``sum_{i=l}^{r} (1 + i // (n-1) + (m - i) // (n-1)) % MOD``.

    NOTE(review): n == 1 makes the divisor 0 and raises ZeroDivisionError;
    presumably the problem guarantees n >= 2 -- confirm against constraints.
    """
    step = n - 1
    count = r - l + 1
    forward = calc(step, l, r)           # sum of i // step over i in [l, r]
    backward = calc(step, m - r, m - l)  # sum of (m - i) // step over i in [l, r]
    return (count + forward + backward) % MOD


def main():
    """Read T followed by T lines of ``N M L R`` from stdin; print each answer."""
    data = sys.stdin.buffer.read().split()
    t = int(data[0])
    answers = []
    for k in range(t):
        N, M, L, R = map(int, data[1 + 4 * k:5 + 4 * k])
        answers.append(solve(N, M, L, R))
    # One write instead of T print calls; emits nothing when T == 0.
    sys.stdout.write("".join(f"{ans}\n" for ans in answers))


if __name__ == "__main__":
    main()