import sys

sys.setrecursionlimit(5 * 10**5)
input = sys.stdin.readline
from collections import defaultdict, deque, Counter
from heapq import heappop, heappush
from bisect import bisect_left, bisect_right
from math import gcd

mod = 998244353


def solve(n, p, mod=mod):
    """Return ``p ** e % mod`` where ``e = sum(n // p**k for k >= 1)``.

    By Legendre's formula, ``e`` is the exponent of ``p`` in ``n!`` when
    ``p`` is prime, so the result is the largest power of ``p`` dividing
    ``n!``, reduced modulo ``mod``.

    Args:
        n: Non-negative integer whose factorial is considered.
        p: Base of the power; expected to be a prime (``p == 1`` is
            accepted and yields 1).
        mod: Modulus applied to the final power (defaults to 998244353).

    Returns:
        ``pow(p, e, mod)``.
    """
    if p == 1:
        # 1 ** anything == 1; without this guard the loop below never
        # terminates because the running power stays at 1 forever.
        return 1 % mod
    exponent = 0
    power = p
    # Sum floor(n / p**k) while p**k <= n; later terms are all zero.
    while power <= n:
        exponent += n // power
        power *= p
    return pow(p, exponent, mod)


def main():
    """Read ``n p`` from stdin and print the answer."""
    n, p = map(int, input().split())
    print(solve(n, p))


if __name__ == "__main__":
    main()