"""Sum of shortest-path distances from N to every k in [1, N].

From a value x one may step to x - 1 (when x > 1) or to x // A (when A > 1
and A divides x).  For each test case (N, A) the program prints
sum_{k=1..N} dist(N, k) modulo 998244353.
"""
from collections import deque

MOD = 998244353


def solve_case(N, A):
    """Return sum_{k=1..N} dist(N, k) modulo MOD in O(log_A N) time.

    Derivation: reverse the walk, so from k we undo steps with +1 or *A
    until we reach N.  A walk that uses exactly j divisions satisfies
        N = (k + d_0) * A**j + d_1 * A**(j-1) + ... + d_j,   d_i >= 0,
    and costs j + sum(d_i).  Any d_i >= A (i >= 1) can be carried into
    d_{i-1}, which lowers the cost, so the optimum uses the base-A digits
    of N mod A**j for d_1..d_j.  That forces d_0 = N // A**j - k, which
    requires k <= N_j := N // A**j.  Hence
        dist(N, k) = min over {j : N_j >= k} of (g_j - k),
        g_j = N_j + j + digitsum_A(N mod A**j).
    For k in (N_{j+1}, N_j] the admissible j' are exactly 0..j, so
    dist(N, k) = (prefix-min of g up to j) - k, which is an arithmetic run.

    Args:
        N: starting value; values k in [1, N] are summed (N <= 1 gives 0).
        A: divisor for the division step; A < 2 means only decrements.

    Returns:
        The sum of distances reduced modulo MOD.
    """
    if N <= 1:
        return 0
    if A < 2:
        # Division is unavailable (A <= 0) or a no-op (A == 1): only
        # decrements remain, so dist(N, k) = N - k.
        return (N * (N - 1) // 2) % MOD

    total = 0
    best = N      # prefix minimum of g_j; g_0 = N
    hi = N        # N_j = N // A**j
    extra = 0     # j + digit sum of (N mod A**j) in base A
    while hi >= 1:
        best = min(best, hi + extra)
        lo = hi // A              # N_{j+1}
        cnt = hi - lo
        # Sum of (best - k) for k in (lo, hi]; (lo+1+hi)*cnt is always even.
        total += cnt * best - (lo + 1 + hi) * cnt // 2
        # Digit j of N in base A is N_j % A; one more division costs 1.
        extra += hi % A + 1
        hi = lo
    return total % MOD


def main():
    """Read T test cases of "N A" from stdin and print one answer per line."""
    t = int(input())
    for _ in range(t):
        n, a = map(int, input().split())
        print(solve_case(n, a))


if __name__ == "__main__":
    main()