MOD = 998244353


def prime_power_in_factorial_mod(n, p, mod=MOD):
    """Return ``p ** v % mod`` where ``v = sum(n // p**i for i >= 1)``.

    For prime ``p`` the exponent ``v`` is the multiplicity of ``p`` in ``n!``
    (Legendre's formula), so the result is the ``p``-part of ``n!`` reduced
    modulo ``mod``.

    The sum is taken over every term until ``p**i`` exceeds ``n``, so there is
    no fixed cap on the number of terms.

    Args:
        n: Assumed to be a non-negative integer. Negative values yield 1.
        p: Base of the power, must be >= 1.
        mod: Modulus applied to the result (default 998244353).

    Returns:
        ``p ** v`` reduced modulo ``mod``.

    Raises:
        ValueError: If ``p < 1``.
    """
    if p < 1:
        raise ValueError("p must be a positive integer")
    if p == 1:
        # 1 ** anything == 1; handling it here also prevents the loop below
        # from running forever, since 1 ** i never exceeds n.
        return 1 % mod

    exponent = 0
    power = p
    while power <= n:
        exponent += n // power
        power *= p
    # One modular exponentiation of the summed exponent equals the product of
    # the per-term powers modulo `mod`.
    return pow(p, exponent, mod)


def main():
    """Read ``N P`` from stdin and print the answer."""
    n, p = map(int, input().split())
    print(prime_power_in_factorial_mod(n, p))


if __name__ == "__main__":
    main()