"""Read N and P, then print P**e mod 998244353, where e = sum_{k>=1} floor(N / P**k).

For prime P, e is the exponent of P in N! (Legendre's formula).
"""

MOD = 998244353


def legendre_exponent(n: int, p: int) -> int:
    """Return sum_{k>=1} floor(n / p**k).

    For prime ``p`` this equals the exponent of ``p`` in ``n!``
    (Legendre's formula). Returns 0 when ``n < p``.

    Raises:
        ValueError: if ``p < 2``. For ``p == 1`` the running power never
            grows, so the loop would never terminate. For ``p == 0`` the
            first division would be by zero.
    """
    if p < 2:
        raise ValueError(f"p must be at least 2, got {p}")
    exponent = 0
    power = p
    # Add floor(n / p**k) for k = 1, 2, ...; stop once p**k exceeds n,
    # because every later term would be zero.
    while power <= n:
        exponent += n // power
        power *= p
    return exponent


def solve(n: int, p: int, mod: int = MOD) -> int:
    """Return ``p ** legendre_exponent(n, p)`` reduced modulo ``mod``."""
    # Three-argument pow performs modular exponentiation without building
    # the (possibly enormous) full power.
    return pow(p, legendre_exponent(n, p), mod)


def main() -> None:
    """Read N and P from stdin (any whitespace-separated layout) and print the answer."""
    import sys

    n, p = map(int, sys.stdin.read().split())
    print(solve(n, p))


if __name__ == "__main__":
    main()