MOD = 998244353


def solve(N, P, mod=MOD):
    """Return ``P ** e % mod`` where ``e = sum(N // P**k for k >= 1)``.

    For a prime ``P`` the exponent ``e`` is the exponent of ``P`` in the
    prime factorisation of ``N!`` (Legendre's formula), so the result is the
    largest power of ``P`` dividing ``N!``, reduced modulo ``mod``.

    Args:
        N: Non-negative integer whose factorial is examined.
        P: Base; expected to be >= 1 (presumably a prime, per the problem).
        mod: Modulus for the final power; defaults to ``MOD``.

    Returns:
        ``pow(P, e, mod)``. When ``P == 1`` this is ``1 % mod``.

    Raises:
        ValueError: If ``N`` is negative or ``P`` is less than 1. Those
            inputs previously caused an infinite loop or a
            ZeroDivisionError.
    """
    if N < 0:
        # With a negative N, N // P gets stuck at -1 and never terminates.
        raise ValueError("N must be non-negative")
    if P < 1:
        raise ValueError("P must be at least 1")
    if P == 1:
        # N // 1 == N, so the loop below would never end. 1 ** e is 1 for
        # every exponent, so the answer does not depend on e.
        return 1 % mod

    exponent = 0
    # Add floor(N / P), floor(N / P^2), ... until the quotient reaches zero.
    while N:
        N //= P
        exponent += N
    return pow(P, exponent, mod)


def main():
    """Read ``N P`` from one line of stdin and print ``solve(N, P)``."""
    N, P = map(int, input().split())
    print(solve(N, P))


if __name__ == "__main__":
    main()