MOD = 998244353


def factorial_prime_power(n: int, p: int, mod: int = MOD) -> int:
    """Return ``p ** e % mod`` where ``e = n//p + n//p**2 + n//p**3 + ...``.

    When ``p`` is prime, ``e`` is the exponent of ``p`` in ``n!`` (Legendre's
    formula), so the result is the largest power of ``p`` dividing ``n!``,
    reduced modulo ``mod``.
    NOTE(review): the script presumably expects a prime ``p`` -- confirm
    against the problem statement; the sum is computed for any ``p >= 2``.

    Args:
        n: Non-negative integer whose factorial is considered.
        p: Base of the power; must be at least 2.
        mod: Modulus applied to the final power.

    Returns:
        ``pow(p, e, mod)`` for the exponent ``e`` described above.

    Raises:
        ValueError: If ``n`` is negative or ``p`` is less than 2. Without this
            check the repeated floor division would never reach zero
            (``p == 1`` or ``n < 0``) or would divide by zero (``p == 0``).
    """
    if n < 0:
        raise ValueError(f"n must be non-negative, got {n}")
    if p < 2:
        raise ValueError(f"p must be at least 2, got {p}")

    exponent = 0
    while n:
        # After k iterations n holds floor(original_n / p**k), so each pass
        # adds the next term of the sum.
        n //= p
        exponent += n
    return pow(p, exponent, mod)


def main() -> None:
    """Read ``N P`` from one line of standard input and print the answer."""
    n, p = map(int, input().split())
    print(factorial_prime_power(n, p))


if __name__ == "__main__":
    main()