import java.util.*;
import java.io.*;
import java.math.*;

/**
 * Program entry point: reads {@code n} and {@code p} from standard input and prints
 * {@code p^count(n, p)} modulo 998244353.
 */
class Main {
    static final int MOD7 = 1000000007;
    static final int MOD9 = 998244353;
    static final int inf = (1 << 30);
    static final long lnf = (1L << 60);
    final String yes = "Yes";
    final String no = "No";

    /**
     * Solves one test case.
     *
     * @param out destination for the answer (no trailing newline is written)
     * @param in  token reader supplying two longs, {@code n} then {@code p}
     */
    void solve(PrintWriter out, In in) {
        MC mc = new MC(MOD9);
        long n = in.nextLong(), p = in.nextLong();
        long kata = count(n, p);
        out.print(mc.power(p, kata));
    }

    /**
     * Returns {@code M/P + M/P^2 + M/P^3 + ...}, using Java's truncating integer division
     * repeatedly until the quotient reaches 0. For a prime {@code P} and {@code M >= 0}
     * this is (by Legendre's formula) the exponent of {@code P} in {@code M!}.
     *
     * @param M the dividend; {@code count(0, P)} is 0 for every {@code P}
     * @param P the divisor; must not be -1, 0 or 1 when {@code M != 0}
     * @return the accumulated sum of successive quotients
     * @throws IllegalArgumentException if {@code M != 0} and {@code P} is -1, 0 or 1
     *     (the quotients would never reach 0, or would divide by zero)
     */
    static long count(long M, long P) {
        if (M == 0) {
            return 0;
        }
        if (-1 <= P && P <= 1) {
            throw new IllegalArgumentException("P must satisfy |P| >= 2, got " + P);
        }
        long total = 0;
        while (M != 0) {
            M /= P;
            total += M;
        }
        return total;
    }

    public static void main(String[] args) {
        PrintWriter pw = new PrintWriter(System.out);
        In in = new In();
        new Main().solve(pw, in);
        pw.flush();
    }
}

/**
 * Modular-arithmetic helper for a fixed positive modulus that fits in an {@code int}.
 * Every result is normalized into {@code [0, mod)}.
 */
class MC {
    private final int mod;

    /**
     * @param mod the modulus
     * @throws IllegalArgumentException if {@code mod <= 0}
     */
    public MC(final int mod) {
        if (mod <= 0) {
            throw new IllegalArgumentException("modulus must be positive, got " + mod);
        }
        this.mod = mod;
    }

    /** Reduces {@code x} into {@code [0, mod)}, also for negative {@code x}. */
    public long mod(long x) {
        x %= mod;
        if (x < 0) {
            x += mod;
        }
        return x;
    }

    /** Returns {@code (a + b) mod m}; operands are reduced first so the sum cannot overflow. */
    public long add(final long a, final long b) {
        return mod(mod(a) + mod(b));
    }

    /**
     * Returns {@code (a * b) mod m}. Operands are reduced first: both are then below
     * {@code 2^31}, so their product fits in a {@code long}.
     */
    public long mul(final long a, final long b) {
        return mod(mod(a) * mod(b));
    }

    /**
     * Returns {@code numerator * denominator^-1 mod m}. Only meaningful when the modulus
     * is prime and {@code denominator} is not a multiple of it (see {@link #inverse}).
     */
    public long div(final long numerator, final long denominator) {
        return mul(numerator, inverse(denominator));
    }

    /**
     * Returns {@code base^exp mod m} by binary exponentiation.
     * Non-positive exponents are not supported as inverses: the result is {@code 1 mod m}.
     */
    public long power(long base, long exp) {
        long ret = mod(1); // 1 mod m, which is 0 when m == 1
        base = mod(base);
        while (exp > 0) {
            if ((exp & 1) == 1) {
                ret = mul(ret, base);
            }
            base = mul(base, base);
            exp >>= 1;
        }
        return ret;
    }

    /**
     * Modular inverse via Fermat's little theorem ({@code x^(m-2)}). Valid only when the
     * modulus is prime; returns 0 when {@code x} is a multiple of the modulus.
     */
    public long inverse(final long x) {
        return power(x, mod - 2);
    }

    /** Returns {@code n! mod m} (1 for {@code n <= 0}). */
    public long factorial(final int n) {
        return product(1, n);
    }

    /** Returns {@code start * (start+1) * ... * end mod m}; 1 when {@code start > end}. */
    public long product(final int start, final int end) {
        long result = mod(1);
        for (int i = start; i <= end; i++) {
            result = mul(result, i);
        }
        return result;
    }

    /**
     * Returns the binomial coefficient {@code C(n, r) mod m}, or 0 when {@code r < 0} or
     * {@code r > n}. Divides by {@code r!}, so it requires a prime modulus greater than
     * {@code r}.
     */
    public long combination(final int n, int r) {
        if (r < 0 || r > n) {
            return 0;
        }
        return div(product(n - r + 1, n), factorial(r));
    }
}

/**
 * Counting queries on a list sorted in ascending order (by {@code compareTo}).
 * Each query is a binary search; use a random-access list (e.g. {@code ArrayList}).
 */
class BinarySearch<T extends Comparable<? super T>> {
    /** Number of elements {@code >= key}. */
    int or_greater(List<? extends T> A, T key) {
        return A.size() - lower_bound(A, key);
    }

    /** Number of elements {@code <= key}. */
    int or_under(List<? extends T> A, T key) {
        return upper_bound(A, key);
    }

    /** Number of elements {@code > key}. */
    int greater(List<? extends T> A, T key) {
        return A.size() - upper_bound(A, key);
    }

    /** Number of elements {@code < key}. */
    int under(List<? extends T> A, T key) {
        return lower_bound(A, key);
    }

    /** First index whose element is not {@code < key}. */
    private int lower_bound(List<? extends T> A, T key) {
        int left = 0;
        int right = A.size();
        while (left < right) {
            int mid = (left + right) >>> 1; // unsigned shift: no overflow on large sizes
            if (compare(A.get(mid), key, 0)) {
                left = mid + 1;
            } else {
                right = mid;
            }
        }
        return right;
    }

    /** First index whose element is {@code > key}. */
    private int upper_bound(List<? extends T> A, T key) {
        int left = 0;
        int right = A.size();
        while (left < right) {
            int mid = (left + right) >>> 1;
            if (compare(A.get(mid), key, 1)) {
                left = mid + 1;
            } else {
                right = mid;
            }
        }
        return right;
    }

    /**
     * With {@code c == 0} returns {@code o1 < o2}; otherwise returns {@code o1 <= o2}.
     * Only the sign of {@code compareTo} is inspected, as its contract guarantees nothing
     * about the magnitude (e.g. {@code String.compareTo} returns character differences).
     */
    boolean compare(T o1, T o2, int c) {
        int res = o1.compareTo(o2);
        return c == 0 ? res < 0 : res <= 0;
    }
}

/**
 * Immutable pair. Natural order: by {@code first} DESCENDING, then by {@code second}
 * ascending.
 */
class Pair<T extends Comparable<? super T>, U extends Comparable<? super U>>
        implements Comparable<Pair<T, U>> {
    private final T first;
    private final U second;

    Pair(T first, U second) {
        this.first = first;
        this.second = second;
    }

    T first() {
        return first;
    }

    U second() {
        return second;
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        if (o == null || getClass() != o.getClass()) return false;
        Pair<?, ?> pair = (Pair<?, ?>) o;
        return Objects.equals(first, pair.first) && Objects.equals(second, pair.second);
    }

    @Override
    public int hashCode() {
        return Objects.hash(first, second);
    }

    @Override
    public int compareTo(Pair<T, U> other) {
        // Operands swapped on purpose: larger `first` sorts earlier.
        int firstResult = other.first.compareTo(this.first);
        if (firstResult != 0) {
            return firstResult;
        }
        return this.second.compareTo(other.second);
    }

    /** Returns {@code first + " " + second}. */
    @Override
    public String toString() {
        return this.first + " " + this.second;
    }
}

/** Immutable triple with lexicographic ascending natural order (first, second, third). */
class Triple<T extends Comparable<? super T>, U extends Comparable<? super U>,
        V extends Comparable<? super V>>
        implements Comparable<Triple<T, U, V>> {
    private final T first;
    private final U second;
    private final V third;

    Triple(T first, U second, V third) {
        this.first = first;
        this.second = second;
        this.third = third;
    }

    T first() {
        return first;
    }

    U second() {
        return second;
    }

    V third() {
        return third;
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        if (o == null || getClass() != o.getClass()) return false;
        Triple<?, ?, ?> triple = (Triple<?, ?, ?>) o;
        return Objects.equals(first, triple.first)
                && Objects.equals(second, triple.second)
                && Objects.equals(third, triple.third);
    }

    @Override
    public int hashCode() {
        return Objects.hash(first, second, third);
    }

    @Override
    public int compareTo(Triple<T, U, V> other) {
        int firstResult = this.first.compareTo(other.first);
        if (firstResult != 0) {
            return firstResult;
        }
        int secondResult = this.second.compareTo(other.second);
        if (secondResult != 0) {
            return secondResult;
        }
        return this.third.compareTo(other.third);
    }
}

/**
 * Buffered token reader over {@code System.in} (captured at construction). Tokens are
 * runs of printable ASCII characters (33..126) separated by anything else.
 */
class In {
    private final InputStream in = System.in;
    private final byte[] buffer = new byte[1024];
    private int ptr = 0;
    private int buflen = 0;

    /** Refills the buffer if exhausted; returns whether at least one unread byte exists. */
    private boolean hasNextByte() {
        if (ptr < buflen) {
            return true;
        }
        ptr = 0;
        try {
            buflen = in.read(buffer);
        } catch (IOException e) {
            // Report and treat as end of input; leaving buflen untouched would
            // re-serve the previous buffer contents.
            e.printStackTrace();
            buflen = -1;
        }
        return buflen > 0;
    }

    /** Returns the next raw byte (signed), or -1 at end of input. */
    private int readByte() {
        return hasNextByte() ? buffer[ptr++] : -1;
    }

    private static boolean isPrintableChar(int c) {
        return 33 <= c && c <= 126;
    }

    /** Skips separators; returns whether another token is available. */
    public boolean hasNext() {
        while (hasNextByte() && !isPrintableChar(buffer[ptr])) {
            ptr++;
        }
        return hasNextByte();
    }

    /** @throws NoSuchElementException if no token remains */
    String next() {
        if (!hasNext()) throw new NoSuchElementException();
        StringBuilder sb = new StringBuilder();
        int b = readByte();
        while (isPrintableChar(b)) {
            sb.appendCodePoint(b);
            b = readByte();
        }
        return sb.toString();
    }

    /**
     * Parses the next token as an optionally negative decimal integer. Overflow is not
     * detected.
     *
     * @throws NoSuchElementException if no token remains
     * @throws NumberFormatException  if the token contains a non-digit
     */
    long nextLong() {
        if (!hasNext()) throw new NoSuchElementException();
        long n = 0;
        boolean minus = false;
        int b = readByte();
        if (b == '-') {
            minus = true;
            b = readByte();
        }
        if (b < '0' || '9' < b) {
            throw new NumberFormatException();
        }
        while (true) {
            if ('0' <= b && b <= '9') {
                n *= 10;
                n += b - '0';
            } else if (b == -1 || !isPrintableChar(b)) {
                return minus ? -n : n;
            } else {
                throw new NumberFormatException();
            }
            b = readByte();
        }
    }

    /** @throws NumberFormatException if the value does not fit in an {@code int} */
    int nextInt() {
        long nl = nextLong();
        if (nl < Integer.MIN_VALUE || nl > Integer.MAX_VALUE) throw new NumberFormatException();
        return (int) nl;
    }

    double nextDouble() {
        return Double.parseDouble(next());
    }

    /** Returns the first character of the next token. */
    char nextChar() {
        return next().charAt(0);
    }

    int[] IntArray(int n) {
        final int[] array = new int[n];
        for (int i = 0; i < n; i++) {
            array[i] = nextInt();
        }
        return array;
    }

    int[][] IntArray(int n, int m) {
        final int[][] array = new int[n][];
        for (int i = 0; i < n; i++) {
            array[i] = IntArray(m);
        }
        return array;
    }

    long[] LongArray(int n) {
        final long[] array = new long[n];
        for (int i = 0; i < n; i++) {
            array[i] = nextLong();
        }
        return array;
    }

    long[][] LongArray(int n, int m) {
        final long[][] array = new long[n][];
        for (int i = 0; i < n; i++) {
            array[i] = LongArray(m);
        }
        return array;
    }

    String[] StringArray(int n) {
        final String[] array = new String[n];
        for (int i = 0; i < n; i++) {
            array[i] = next();
        }
        return array;
    }

    /** Reads {@code n} tokens, keeping the first character of each. */
    char[] CharArray(int n) {
        final char[] array = new char[n];
        for (int i = 0; i < n; i++) {
            array[i] = next().charAt(0);
        }
        return array;
    }

    /**
     * Reads {@code n} tokens as grid rows. Each row's length is that of its token;
     * {@code m} is not enforced.
     */
    char[][] CharArray(int n, int m) {
        final char[][] array = new char[n][];
        for (int i = 0; i < n; i++) {
            array[i] = next().toCharArray();
        }
        return array;
    }

    /** Reads an {@code n x m} grid given as {@code n*m} whitespace-separated characters. */
    char[][] CharArray2(int n, int m) {
        final char[][] array = new char[n][m];
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < m; j++) {
                array[i][j] = next().charAt(0);
            }
        }
        return array;
    }
}