def exp(a,b,p): ans,mul,div=1,a,1 for i in range(70): if b//div%2==1: ans*=mul ans%=p mul=mul**2%p div*=2 return ans%p N,P=map(int,input().split()) ci=0 p=998244353 while N!=0: N//=P ci+=N ci%=p ans=1 print(exp(P,ci%p,p)%p)