n,p=map(int,input().split())
mod=998244353
tot=0
i=1
while p**i<=n:
  tot+=n//(p**i)
  i+=1
print(pow(p,tot,mod))