MOD=998244353 N,P=map(int,input().split()) N//=P ans=0 a=1 while N: ans+=a * (N%P) ans%=MOD a = a*P+1 N//=P print(pow(P, ans, MOD))