import java.io.*;
import java.util.ArrayList;
import java.util.List;

/**
 * Reads {@code t} test cases of the form {@code n a} from standard input and prints, one per
 * line, the value computed by {@link #solve(long, long)} modulo {@value #MOD}.
 */
public class Main {
    /** Prime modulus applied to every printed answer. */
    static final int MOD = 998244353;

    /** Modular inverse of 2 under {@link #MOD}; used to halve the arithmetic-series sums. */
    private static final long INV2 = modinv(2, MOD);

    public static void main(String[] args) throws Exception {
        try (BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
                PrintWriter pw = new PrintWriter(System.out)) {
            int t = Integer.parseInt(nextLine(br));
            for (int tc = 0; tc < t; tc++) {
                // Split on any run of whitespace so extra/leading spaces do not yield empty tokens.
                String[] parts = nextLine(br).split("\\s+");
                long n = Long.parseLong(parts[0]);
                long a = Long.parseLong(parts[1]);
                pw.println(solve(n, a));
            }
            pw.flush();
        }
    }

    /**
     * Computes the answer for one test case, modulo {@link #MOD}.
     *
     * <p>For {@code a == 1} the result is the closed form {@code n(n-1)/2}. Otherwise {@code n} is
     * repeatedly divided by {@code a}. At each step with current value {@code cur} and quotient
     * {@code next = cur / a}, the step adds the series {@code s, s+1, ..., s+d-1} with
     * {@code d = cur - next}. It then increases the running offset {@code s} by
     * {@code (cur mod a) + 1}.
     *
     * @param n non-negative input value
     * @param a positive divisor base
     * @return the accumulated sum reduced into {@code [0, MOD)}
     * @throws IllegalArgumentException if {@code n < 0} or {@code a < 1}
     */
    static long solve(long n, long a) {
        if (n < 0) {
            throw new IllegalArgumentException("n must be non-negative: " + n);
        }
        if (a < 1) {
            throw new IllegalArgumentException("a must be positive: " + a);
        }
        if (a == 1) {
            // The division loop below would never terminate for a == 1 (cur / 1 == cur).
            return n % MOD * ((n - 1) % MOD) % MOD * INV2 % MOD;
        }
        long ans = 0;
        long s = 0; // running offset: sum of ((base-a digit) + 1) over the digits consumed so far
        long cur = n;
        while (cur > 0) {
            long next = cur / a;
            long d = cur - next; // >= 1 because a >= 2 and cur >= 1
            // d * (2s + d - 1) / 2, with each factor reduced mod MOD first so nothing overflows.
            long first = (2 * (s % MOD) + (d - 1) % MOD) % MOD;
            ans = (ans + first * (d % MOD) % MOD * INV2) % MOD;
            s += cur - next * a + 1; // cur - next * a == cur mod a
            cur = next;
        }
        return ans;
    }

    /** Reads the next line, trimmed; fails with a clear error instead of returning null at EOF. */
    private static String nextLine(BufferedReader br) throws IOException {
        String line = br.readLine();
        if (line == null) {
            throw new EOFException("unexpected end of input");
        }
        return line.trim();
    }

    /**
     * Returns the inverse of {@code a} modulo {@code m} using the extended Euclidean algorithm.
     * The result is normalized into {@code [0, m)}. It is only meaningful when gcd(a, m) == 1.
     */
    static long modinv(long a, int m) {
        long r0 = a; // current remainder pair
        long r1 = m;
        long x0 = 1; // coefficient of a for r0 / r1
        long x1 = 0;
        while (r1 > 0) {
            long q = r0 / r1;
            long r2 = r0 - q * r1;
            r0 = r1;
            r1 = r2;
            long x2 = x0 - q * x1;
            x0 = x1;
            x1 = x2;
        }
        long u = x0 % m;
        return u < 0 ? u + m : u;
    }
}