import sys
import threading  # NOTE(review): not referenced in this file as shown; kept in case something relies on it

MOD = 998244353


def _solve(n, a, mod=MOD):
    """Return the answer for one test case (N=n, A=a), reduced modulo ``mod``.

    For a == 1 the per-K value is taken as f(K) = n - K, so the answer is
    sum_{K=1..n} (n - K) = n*(n-1)/2.

    Otherwise the answer is sum over K = 1..n of
        n + m(K) - K * a**m(K),
    where m(K) is the largest exponent with K * a**m(K) <= n.  Rather than
    visiting every K, the K values are grouped by exponent: with P = a**m,
    exactly the K in the half-open range (n // (P*a), n // P] satisfy
    K*P <= n < K*P*a, i.e. share m(K) == m.  Each group therefore adds
    m * (group size) to the m-sum and P * (sum of K in the group) to the
    subtracted sum.  Only O(log_a n) groups exist.

    NOTE(review): assumes a >= 1 (a == 0 raises ZeroDivisionError, negative
    a is not meaningful for this grouping) — confirm against the problem
    constraints.
    """
    if a == 1:
        # f(K) = N - K, sum = N*(N-1)/2 (exact integer division, then reduce)
        return n * (n - 1) // 2 % mod

    sum_m = 0      # sum of m(K) over all K, mod `mod`
    sum_pk = 0     # sum of K * a**m(K) over all K, mod `mod`
    power = 1      # P = a**m
    m = 0
    while power <= n:
        hi = n // power            # largest K with K * P <= n
        lo = n // (power * a)      # largest K with K * P * a <= n (Python ints never overflow)
        sum_m = (sum_m + m * ((hi - lo) % mod)) % mod
        # Sum of K for K in (lo, hi] via triangular numbers.
        k_sum = (hi * (hi + 1) // 2 - lo * (lo + 1) // 2) % mod
        sum_pk = (sum_pk + (power % mod) * k_sum) % mod
        m += 1
        power *= a

    # Every K contributes n, giving n*n in total.
    n_squared = (n % mod) * (n % mod) % mod
    return (n_squared + sum_m - sum_pk) % mod


def main():
    """Read T followed by T pairs (N, A) from stdin; print one answer per line."""
    tokens = iter(sys.stdin.read().split())
    t = int(next(tokens))
    answers = []
    for _ in range(t):
        n = int(next(tokens))
        a = int(next(tokens))
        answers.append(_solve(n, a))
    # Each answer is newline-terminated, so the output ends with a newline.
    sys.stdout.write("".join(f"{ans}\n" for ans in answers))


if __name__ == '__main__':
    main()