# Reads two integers N and P from one line of standard input and prints
# P ** e modulo 998244353, where e = floor(N/P) + floor(N/P^2) + ...

MOD = 998244353


def legendre_exponent(n: int, p: int) -> int:
    """Return floor(n/p) + floor(n/p**2) + floor(n/p**3) + ...

    When ``p`` is prime this sum is the exponent of ``p`` in ``n!``
    (Legendre's formula).

    Args:
        n: Non-negative integer.
        p: Integer base, at least 2.

    Raises:
        ValueError: If ``n`` is negative or ``p`` is less than 2. Repeated
            floor division would never reach zero for such inputs (or would
            divide by zero when ``p`` is 0).
    """
    if n < 0 or p < 2:
        raise ValueError(f"expected n >= 0 and p >= 2, got n={n}, p={p}")

    exponent = 0
    while n:
        # After k iterations n holds floor(original_n / p**k).
        n //= p
        exponent += n
    return exponent


def solve(n: int, p: int) -> int:
    """Return ``p ** legendre_exponent(n, p)`` reduced modulo ``MOD``."""
    # The exponent grows roughly like n / (p - 1), so multiplying it out one
    # factor at a time is infeasible for large n; built-in modular
    # exponentiation needs only O(log exponent) multiplications.
    return pow(p, legendre_exponent(n, p), MOD)


def main() -> None:
    """Read ``N P`` from standard input and print the answer."""
    n, p = map(int, input().split())
    print(solve(n, p))


if __name__ == "__main__":
    main()