MOD = 998244353


def solve(N, P):
    """Return P ** e modulo MOD, where e = sum_{k>=1} floor(N / P**k).

    For prime P, e is the exponent of P in N! (Legendre's formula), so the
    result is the largest power of P dividing N!, reduced modulo MOD.

    Args:
        N: non-negative integer.
        P: integer base, at least 2.

    Returns:
        ``pow(P, e, MOD)`` with the exponent ``e`` computed exactly.

    Raises:
        ValueError: if P < 2 or N < 0 (the summation loop would never
            terminate, or would divide by zero, for such inputs).
    """
    if P < 2:
        raise ValueError("P must be at least 2")
    if N < 0:
        raise ValueError("N must be non-negative")
    # Accumulate floor(N/P) + floor(N/P^2) + ... exactly. The exponent must
    # NOT be reduced modulo MOD: P**MOD is not congruent to P**0 mod MOD in
    # general. Python ints are unbounded, and three-argument pow() needs only
    # O(log e) multiplications, so the exact value is cheap to use.
    exponent = 0
    while N:
        N //= P
        exponent += N
    return pow(P, exponent, MOD)


if __name__ == "__main__":
    N, P = map(int, input().split())
    print(solve(N, P))