"""Read ``n p`` from stdin and print p ** (sum_{i>=1} floor(n / p**i)) mod 998244353.

For prime ``p`` the exponent is Legendre's formula, i.e. the exponent of ``p``
in ``n!``. Presumably the intended problem guarantees ``p`` is prime — TODO
confirm; the code computes the same sum for any ``p >= 2``.
"""

MOD = 998244353


def legendre_exponent(n, p):
    """Return sum_{i>=1} floor(n / p**i).

    Args:
        n: Upper bound of the factorial; values below ``p`` give 0.
        p: Base, must be at least 2 (for ``p < 2`` the powers never exceed ``n``).

    Raises:
        ValueError: if ``p < 2``.
    """
    if p < 2:
        raise ValueError("p must be at least 2")
    total = 0
    power = p
    # No fixed iteration cap: Python ints are unbounded, so keep multiplying
    # until the power exceeds n (a fixed cap would truncate the sum for huge n).
    while power <= n:
        total += n // power
        power *= p
    return total


def solve(n, p, mod=MOD):
    """Return p ** legendre_exponent(n, p) modulo ``mod``.

    ``p == 1`` is special-cased: every power of 1 is 1, while the exponent sum
    itself would never terminate.

    Raises:
        ValueError: if ``p < 1``.
    """
    if p == 1:
        return 1 % mod
    return pow(p, legendre_exponent(n, p), mod)


def main():
    """Parse ``n p`` from one line of stdin and print the answer."""
    n, p = map(int, input().split())
    print(solve(n, p))


if __name__ == "__main__":
    main()