"""Answer queries ``(n, a)`` with a closed-form sum taken modulo 998244353.

Input format (stdin): the first line holds the number of test cases ``t``;
each of the next ``t`` lines holds two integers ``n`` and ``a``.  One result
is printed per test case.
"""

# Prime modulus applied to every result returned by ``solve``.
mod = 998244353


def digsum(n, a):
    """Return the sum of the digits of ``n`` written in base ``a``.

    Non-positive ``n`` yields 0 because the loop body never runs.

    NOTE: ``a`` is expected to be at least 2.  With ``a == 1`` and ``n > 0``
    the value of ``n`` never shrinks, so the loop would not terminate;
    ``a == 0`` raises ``ZeroDivisionError``.
    """
    res = 0
    while n > 0:
        n, digit = divmod(n, a)
        res += digit
    return res


def solve(n, a):
    """Return the answer for one query, reduced modulo ``mod``.

    For ``a == 1`` the result is ``n * (n - 1) // 2``.

    Otherwise, for every exponent ``k`` with ``a**k <= n`` the function adds
    ``f_k * c_k`` where

    * ``c_k = n // a**k - n // a**(k + 1)`` is the number of integers in
      ``[1, n]`` divisible by ``a**k`` but not by ``a**(k + 1)``, and
    * ``f_k = n // a**k + digsum(n % a**k, a) + k``,

    and finally subtracts ``n * (n + 1) // 2``.
    """
    if a == 1:
        return (n * (n - 1) // 2) % mod

    k, power, ans = 0, 1, 0
    while power <= n:
        quotient = n // power  # reused for both the weight and the count
        f = (quotient + digsum(n % power, a) + k) % mod
        count = quotient - n // (power * a)
        ans = (ans + f * count) % mod
        power *= a
        k += 1

    # Python's % with a positive modulus is always non-negative, so no
    # extra "+ mod" correction is needed after the subtraction.
    return (ans - n * (n + 1) // 2) % mod


def main():
    """Read ``t`` test cases from stdin and print ``solve(n, a)`` for each."""
    t = int(input())
    for _ in range(t):
        n, a = map(int, input().split())
        print(solve(n, a))


# Guard the driver so that importing this module does not consume stdin.
if __name__ == "__main__":
    main()