#!/usr/bin/env python3
"""Read ``N P`` from stdin and print P**k mod MOD, where k = sum(N // P**i, i >= 1).

When P is prime, k is the exponent of P in N! (Legendre's formula), so the
printed value is the largest power of P dividing N!, reduced modulo MOD.
"""
import sys
import io
import math
import collections
import decimal
import itertools
import bisect
import heapq


def input():
    """Return the next line of stdin with its line terminator removed.

    Stripping only trailing "\\r"/"\\n" (rather than slicing off the last
    character) keeps the final line intact when the input does not end
    with a newline.
    """
    return sys.stdin.readline().rstrip("\r\n")


INF = 10**10
MOD = 998244353


def solve(N: int, P: int, mod: int = MOD) -> int:
    """Return P**k % mod, where k = N//P + N//P**2 + N//P**3 + ...

    For prime P, k is the exponent of P in the prime factorisation of N!
    (Legendre's formula). Primality is not checked here.

    Args:
        N: Non-negative integer whose factorial is considered.
        P: Base; must be at least 2.
        mod: Modulus for the result (defaults to MOD).

    Returns:
        pow(P, k, mod). This is 1 % mod when P > N, because then k == 0.

    Raises:
        ValueError: If P < 2. With P == 1 the sum never terminates, and with
            P == 0 it divides by zero.
    """
    if P < 2:
        raise ValueError(f"P must be at least 2, got {P}")
    exponent = 0
    power = P
    # Each term N // P**i counts the multiples of P**i in 1..N.
    while power <= N:
        exponent += N // power
        power *= P
    return pow(P, exponent, mod)


if __name__ == "__main__":
    N, P = map(int, input().split())
    print(solve(N, P))