# coding: utf-8
# Your code here!
MOD = 998244353


def legendre_exponent(n, p):
    """Return sum over k >= 1 of floor(n / p**k).

    When p is prime this is the exponent of p in n! (Legendre's formula).
    Terms with p**k > n are zero, so the loop stops as soon as the power
    exceeds n instead of at a fixed cap. That keeps the result exact for
    arbitrarily large n.

    Raises:
        ValueError: if p < 2, because the powers of p would never exceed n
            and the series would not terminate.
    """
    if p < 2:
        raise ValueError("p must be at least 2")
    total = 0
    power = p
    while power <= n:
        total += n // power
        power *= p
    return total


def main():
    """Read "N P" from stdin and print P ** legendre_exponent(N, P) mod MOD."""
    n, p = map(int, input().split())
    print(pow(p, legendre_exponent(n, p), MOD))


if __name__ == "__main__":
    main()