"""Read ``N P`` from stdin and print ``P ** e mod 998244353``.

``e`` is the sum of ``N // P**k`` over every power ``P**k <= N``.  When
``P`` is prime this sum is the exponent of ``P`` in ``N!`` (Legendre's
formula), so the printed value is then the largest power of ``P`` dividing
``N!``, reduced modulo 998244353.
"""
# The wildcard imports are kept exactly as in the original script because
# other parts of the file may rely on the names they bring into scope.
from collections import *
from itertools import *
from functools import *
from heapq import *
import sys
import math

# Line reader bound to stdin; the original module-level name is preserved.
input = sys.stdin.readline

# Modulus applied to the printed result.
mod = 998244353


def legendre_exponent(n: int, p: int) -> int:
    """Return the sum of ``n // p**k`` over all ``k >= 1`` with ``p**k <= n``.

    For a prime ``p`` this is the exponent of ``p`` in ``n!``.

    Args:
        n: Upper bound of the factorial; any ``n < p`` yields 0.
        p: Base of the powers; must be at least 2.

    Returns:
        The accumulated sum of floor divisions.

    Raises:
        ValueError: If ``p < 2``.  Such a base would make the loop below
            never terminate (``p == 1`` or negative ``p``) or divide by
            zero (``p == 0``).
    """
    if p < 2:
        raise ValueError(f"p must be at least 2, got {p}")

    total = 0
    power = p
    # Each pass adds the count of multiples of the current power in 1..n.
    while power <= n:
        total += n // power
        power *= p
    return total


def solve(n: int, p: int) -> int:
    """Return ``p ** legendre_exponent(n, p)`` reduced modulo ``mod``.

    Raises:
        ValueError: If ``p < 2`` (propagated from ``legendre_exponent``).
    """
    # Three-argument pow performs modular exponentiation without ever
    # building the full power.
    return pow(p, legendre_exponent(n, p), mod)


def main() -> None:
    """Read ``N P`` from one line of stdin and print the answer."""
    n, p = map(int, input().split())
    print(solve(n, p))


if __name__ == "__main__":
    main()