import static java.lang.System.err;

/**
 * Reads {@code N} and {@code P} from standard input and prints
 * {@code P^e mod 998244353}, where {@code e = sum_{k>=1} floor(N / P^k)}.
 *
 * <p>For a prime {@code P}, {@code e} is the exponent of {@code P} in {@code N!}
 * (Legendre's formula).
 */
public class Main {
	/** Modulus applied to the printed answer. */
	private static final int MOD = 998_244_353;

	public static void main(String[] args) {
		java.io.PrintWriter out = new java.io.PrintWriter(System.out);
		new Main(out);
		out.flush();
		err.flush();
	}

	/**
	 * Solves one case: reads {@code N P} from {@code System.in} and writes the answer to {@code out}.
	 *
	 * @param out destination for the answer; the caller is responsible for flushing it
	 */
	public Main(java.io.PrintWriter out) {
		try (java.util.Scanner sc = new java.util.Scanner(System.in)) {
			long N = sc.nextLong(), P = sc.nextLong();
			out.println(pow(P, legendreExponent(N, P), MOD));
		}
	}

	/**
	 * Returns {@code sum_{k>=1} floor(n / p^k)}.
	 *
	 * <p>Relies on {@code floor(floor(n / p^k) / p) == floor(n / p^(k+1))}, so powers of
	 * {@code p} are never computed and nothing can overflow, even for {@code n} near
	 * {@code Long.MAX_VALUE}. The sum is at most {@code n / (p - 1)}, so it fits in a long.
	 *
	 * @param n upper bound of the factorial (non-positive values yield 0)
	 * @param p base; for {@code p <= 1} the series is not meaningful (it would never
	 *          terminate or would divide by zero), so 0 is returned
	 * @return the summed quotients
	 */
	static long legendreExponent(long n, long p) {
		if (p <= 1) {
			return 0;
		}
		long total = 0;
		for (long q = n / p; q > 0; q /= p) {
			total += q;
		}
		return total;
	}

	/**
	 * Computes {@code a^b mod mod} by binary exponentiation.
	 *
	 * @param a   base; any long, including negative values or values larger than {@code mod}
	 * @param b   exponent; a negative exponent is mapped into {@code [0, mod-1)} plus
	 *            {@code mod-1} via Fermat's little theorem, which assumes {@code mod} is prime
	 * @param mod positive modulus
	 * @return the result in {@code [0, mod)}
	 */
	public static int pow(long a, long b, int mod) {
		if (b < 0) {
			b = b % (mod - 1) + mod - 1;
		}
		// Reduce first: with base < mod <= Integer.MAX_VALUE, base * base fits in a long.
		long base = Math.floorMod(a, (long) mod);
		long result = 1 % mod;
		for (; b > 0; b >>= 1, base = base * base % mod) {
			if ((b & 1) != 0) {
				result = result * base % mod;
			}
		}
		return (int) result;
	}
}