import sys

MOD = 998244353
# Modular inverse of 2: 2 * (MOD + 1) // 2 == MOD + 1 ≡ 1 (mod MOD), valid because MOD is odd.
INV2 = (MOD + 1) // 2


def solve(n, a):
    """Return the answer for one test case ``(n, a)`` modulo ``MOD``.

    For ``a == 1`` the result is ``n * (n - 1) / 2`` (mod MOD). Otherwise the
    value ``n`` is repeatedly replaced by ``n // a``; at each step the
    ``d = m_prev - m_prev // a`` values that drop out contribute
    ``d * s_prev + d * (d - 1) / 2``, where ``s_prev`` is the running sum of
    ``(m_prev % a) + 1`` over the previous steps. The loop runs
    O(log_a n) times because the value is divided by ``a`` every iteration.

    Args:
        n: Non-negative integer (n <= 0 yields 0 when a >= 2).
        a: Divisor; assumed to be >= 1 (a == 0 raises ZeroDivisionError).

    Returns:
        An int in ``[0, MOD)``.
    """
    if a == 1:
        # Special-cased: with a == 1 the loop below would never shrink m_prev.
        # n * (n - 1) is always even, so multiplying by INV2 equals exact halving.
        return n * (n - 1) * INV2 % MOD

    total = 0
    m_prev = n
    s_prev = 0  # small (bounded by a * log_a n), so it is kept exact, not reduced
    while m_prev > 0:
        r, m = m_prev % a, m_prev // a
        d = (m_prev - m) % MOD
        # d * (d - 1) / 2 taken modulo MOD via the inverse of 2.
        total = (total + d * s_prev + d * (d - 1) % MOD * INV2) % MOD
        s_prev += r + 1
        m_prev = m
    return total


def main():
    """Read ``t`` followed by ``t`` pairs ``n a`` from stdin; print one answer per line."""
    # Read everything at once: much faster than one input() call per test case.
    tokens = sys.stdin.read().split()
    t = int(tokens[0])
    answers = [
        solve(int(tokens[1 + 2 * i]), int(tokens[2 + 2 * i])) for i in range(t)
    ]
    if answers:
        sys.stdout.write("\n".join(map(str, answers)) + "\n")


# Only consume stdin when run as a script, so the module can be imported safely.
if __name__ == "__main__":
    main()