n,p=map(int,input().split()) mod=998244353 tot=0 i=1 while p**i<=n: tot+=n//(p**i) i+=1 print(pow(p,tot,mod))