import sys

mod = 998244353
twoinv = 499122177  # modular inverse of 2: (2 * twoinv) % mod == 1; kept for compatibility

# Module-level defaults. The original script stored each test case here and f()
# read ``a`` implicitly; f(x) called without an explicit base still uses ``a``.
n, a = 0, 0


def f(x, base=None):
    """Evaluate the recurrence below for ``x`` and divisor ``base``, modulo ``mod``.

    With ``b, c = divmod(x, base)``::

        f(x) = x*(x-1)/2                          if base == 1 or x < base
        f(x) = (c+1)*b + (x-b)*(x-b-1)/2 + f(b)   otherwise

    NOTE(review): the terms look like pair counts (x*(x-1)/2 = C(x, 2), and
    b = x // base is the number of multiples of ``base`` in 1..x), but the
    problem statement is not visible here -- TODO confirm the intended meaning.

    Args:
        x: argument of the recurrence.
        base: the divisor. Defaults to the module-level ``a`` for backward
            compatibility with the original ``f(x)`` call style.

    Returns:
        The value of the recurrence reduced into ``[0, mod)``.

    Raises:
        ValueError: if ``base < 1``. The original divided by zero for
            ``base == 0`` and could recurse without end for negative bases.
    """
    if base is None:
        base = a
    if base < 1:
        raise ValueError(f"base must be >= 1, got {base}")

    total = 0
    # The original recursion is a plain chain x -> x // base, so it unrolls
    # into a loop of O(log_base x) steps with no recursion depth to worry about.
    while base != 1 and x >= base:
        b, c = divmod(x, base)
        # (x-b-1)*(x-b) is a product of consecutive integers, hence even, so
        # exact // 2 agrees (mod p) with the original multiply-by-twoinv.
        total = (total + (c + 1) * b + (x - b - 1) * (x - b) // 2) % mod
        x = b
    # Base case of the recurrence: x*(x-1)/2 (again an exact halving).
    return (total + x * (x - 1) // 2) % mod


def main():
    """Read ``t`` and then ``t`` pairs ``n a`` from stdin; print ``f(n, a)`` for each.

    All input is read at once and all output written at once, which is much
    faster than one ``input()``/``print()`` call per test case for large ``t``.
    """
    tokens = iter(sys.stdin.buffer.read().split())
    t = int(next(tokens, 0))
    results = []
    for _ in range(t):
        x = int(next(tokens))
        base = int(next(tokens))
        results.append(str(f(x, base)))
    if results:
        sys.stdout.write("\n".join(results) + "\n")


if __name__ == "__main__":
    main()