"""Print P**e mod 998244353, where e = sum of floor(N / P**k) for k >= 1.

For a prime P, e is the exponent of P in N! (Legendre's formula), so the
printed value is the largest power of P dividing N!, reduced modulo MOD.

Input: one line on standard input holding two whitespace-separated
integers ``N P``. Any further tokens on that line are ignored.
"""

MOD = 998244353


def prime_exponent_in_factorial(n: int, p: int) -> int:
    """Return sum(n // p**k for k >= 1); for prime *p*, the exponent of p in n!.

    Args:
        n: The integer whose factorial is considered. Values below *p*
            (including negatives) yield 0.
        p: The base; must be at least 2.

    Raises:
        ValueError: If *p* is less than 2. With p == 1 the repeated division
            never shrinks *n* (the loop would not terminate), and smaller
            values divide by zero or give a meaningless result.
    """
    if p < 2:
        raise ValueError(f"p must be at least 2, got {p}")

    exponent = 0
    # Dividing repeatedly yields n // p, n // p**2, ... without ever building
    # the (potentially huge) powers of p themselves.
    while n >= p:
        n //= p
        exponent += n
    return exponent


def factorial_prime_power_mod(n: int, p: int, mod: int = MOD) -> int:
    """Return p ** prime_exponent_in_factorial(n, p), reduced modulo *mod*."""
    # Three-argument pow performs modular exponentiation, so the very large
    # exponent is never expanded into a full power.
    return pow(p, prime_exponent_in_factorial(n, p), mod)


def main() -> None:
    """Read ``N P`` from one line of stdin and print the result."""
    tokens = input().split()
    n = int(tokens[0])
    p = int(tokens[1])
    print(factorial_prime_power_mod(n, p))


if __name__ == "__main__":
    main()