"""Evaluate, per test case (N, A), a sum over K = 1..N modulo 998244353.

Input format: the first token is T, followed by T pairs "N A".
For each pair, one answer line is written to stdout.
"""

import sys

MOD = 998244353


def solve_case(N, A):
    """Return sum_{K=1}^{N} (N - K * A**e(K) + e(K)) modulo MOD.

    e(K) is the largest exponent e >= 0 such that K * A**e <= N.
    For A == 1 the code takes e(K) = 0, so the sum reduces to
    sum_{K=1}^{N} (N - K) = N * (N - 1) / 2.

    All K sharing the same exponent e form one contiguous range
    [N // A**(e+1) + 1, N // A**e], so the sum is evaluated range by range
    in O(log_A N) steps instead of O(N).

    Args:
        N: upper bound of K; assumed N >= 0 (N == 0 yields 0).
        A: multiplier; assumed to be a positive integer
           (A == 0 raises ZeroDivisionError).

    Returns:
        The sum reduced modulo MOD, an int in [0, MOD).
    """
    if A == 1:
        # Every K has e(K) = 0, so each summand is just N - K.
        # N * (N - 1) is always even, so exact halving is safe.
        return N * (N - 1) // 2 % MOD

    # powers[e] = A**e for every e with A**e <= N.
    # The loop condition avoids computing a power that overshoots N.
    powers = [1]
    while powers[-1] <= N // A:
        powers.append(powers[-1] * A)
    max_exp = len(powers) - 1

    sum_exp = 0  # sum of e(K) over K = 1..N, mod MOD
    sum_top = 0  # sum of K * A**e(K) over K = 1..N, mod MOD
    for e, p in enumerate(powers):
        hi = N // p
        # The largest exponent owns every remaining K down to 1.
        lo = 1 if e == max_exp else N // powers[e + 1] + 1
        if lo > hi:
            # Empty range (e.g. N == 0).
            continue
        cnt = hi - lo + 1
        # Exact arithmetic-series sum lo + ... + hi, reduced afterwards.
        range_sum = (lo + hi) * cnt // 2 % MOD
        sum_exp = (sum_exp + e * cnt) % MOD
        sum_top = (sum_top + p % MOD * range_sum) % MOD

    n_mod = N % MOD
    # sum of N over K = 1..N is N * N.
    return (sum_exp + n_mod * n_mod - sum_top) % MOD


def solve_all(data):
    """Parse raw input bytes/str ("T" then T pairs "N A") and return all answers.

    Returns one answer per line, each terminated by a newline.
    Empty input yields an empty string.
    """
    tokens = data.split()
    if not tokens:
        return ""
    t = int(tokens[0])
    out = []
    for idx in range(t):
        n = int(tokens[1 + 2 * idx])
        a = int(tokens[2 + 2 * idx])
        out.append(f"{solve_case(n, a)}\n")
    return "".join(out)


def main():
    """Read all test cases from stdin at once and write all answers in one call."""
    sys.stdout.write(solve_all(sys.stdin.buffer.read()))


if __name__ == "__main__":
    main()