import sys

MOD = 998244353


def solve(n, a):
    """Return the answer for one test case ``(n, a)``, reduced modulo ``MOD``.

    The computation walks the chain y = n, n // a, n // a**2, ... down to 0.
    At each step (``level`` = number of steps already taken) it sets
    ``span = y - y // a`` and adds

        span * (span - 1) // 2  +  span * (level + residue_sum)

    where ``residue_sum`` is the sum of ``y % a`` over the earlier steps.
    NOTE(review): the combinatorial meaning of this sum comes from a problem
    statement that is not part of this file; the arithmetic above is exactly
    what the code does.

    Args:
        n: first integer of the test case.
        a: second integer of the test case; ``a == 0`` raises
            ``ZeroDivisionError`` as in the original script.

    Returns:
        The accumulated total modulo ``MOD``.
    """
    if a == 1:
        # The loop below would never terminate for a == 1 (n // 1 == n),
        # so the closed form n*(n-1)/2 is used instead.
        return n * (n - 1) // 2 % MOD

    total = 0
    level = 0        # how many chain steps have been processed so far
    residue_sum = 0  # running sum of y % a over the processed steps
    y = n
    while y > 0:
        x = y // a
        span = y - x
        total += span * (span - 1) // 2
        total += span * (level + residue_sum)
        # Update after use: the current step's residue only affects later steps.
        residue_sum += y % a
        y = x
        level += 1
    return total % MOD


def main():
    """Read ``T`` followed by ``T`` pairs ``n a`` from stdin; print one answer per line."""
    # Token-based read: tolerant of stray whitespace, identical for well-formed input.
    data = sys.stdin.buffer.read().split()
    t = int(data[0])
    results = []
    for i in range(t):
        n = int(data[1 + 2 * i])
        a = int(data[2 + 2 * i])
        results.append(solve(n, a))
    # Single batched write; emit nothing at all when T == 0 (matches per-case printing).
    if results:
        print("\n".join(map(str, results)))


if __name__ == "__main__":
    main()