"""Print P**e modulo 998244353, where e = N//P + N//P**2 + N//P**3 + ...

Reads two integers N and P from standard input (whitespace separated, on one
line or several) and prints ``pow(P, e, 998244353)``.  When P is prime, e is
the exponent of P in N! (Legendre's formula), so the output is the largest
power of P dividing N!, reduced modulo 998244353.
"""

import sys

MOD = 998244353


def legendre_exponent(n: int, p: int) -> int:
    """Return ``sum(n // p**k for k >= 1)``.

    Args:
        n: Upper bound of the factorial; values below ``p`` give 0.
        p: Base, must be at least 2.

    Returns:
        The sum of ``n // p**k`` over all k >= 1 (only finitely many terms
        are non-zero).

    Raises:
        ValueError: if ``p < 2``; for ``p == 1`` the sum never terminates and
            for ``p <= 0`` it is not meaningful.
    """
    if p < 2:
        raise ValueError(f"base must be >= 2, got {p}")
    total = 0
    while n >= p:
        # Repeated floor division yields n//p, n//p**2, ... without ever
        # materialising large powers of p.
        n //= p
        total += n
    return total


def solve(n: int, p: int, mod: int = MOD) -> int:
    """Return ``p ** legendre_exponent(n, p)`` reduced modulo *mod*."""
    return pow(p, legendre_exponent(n, p), mod)


def main() -> None:
    """Read N and P from stdin and print the answer."""
    # Read all of stdin so N and P may be on one line or on separate lines;
    # any extra tokens are ignored, as in a first-line split.
    n_str, p_str = sys.stdin.read().split()[:2]
    print(solve(int(n_str), int(p_str)))


if __name__ == "__main__":
    main()