"""Print P ** (sum of N // P**k for k >= 1) modulo 998244353.

When P is prime, the exponent is the power of P dividing N! (Legendre's
formula), so the output is the largest power of P dividing N!, reduced
modulo 998244353.  Input: a single line "N P" on stdin.
"""

MOD = 998244353


def legendre_exponent(n, p):
    """Return sum(n // p**k for k >= 1 while p**k <= n).

    For prime ``p`` this is the exponent of ``p`` in ``n!``.

    Args:
        n: Non-negative integer (values below ``p`` yield 0).
        p: Base; must be at least 2.

    Returns:
        The summed exponent as a non-negative int.

    Raises:
        ValueError: If ``p < 2``. With p == 1 the power would never grow
            (infinite loop), and with p == 0 the division would fail.
    """
    if p < 2:
        raise ValueError("p must be at least 2")
    exponent = 0
    power = p
    # Each term counts the multiples of p**k in 1..n.
    while power <= n:
        exponent += n // power
        power *= p
    return exponent


def solve(n, p, mod=MOD):
    """Return ``p ** legendre_exponent(n, p) % mod``.

    Uses the three-argument ``pow``, so the potentially huge power is
    never materialized.
    """
    return pow(p, legendre_exponent(n, p), mod)


def main():
    """Read "N P" from stdin and print the answer."""
    n, p = map(int, input().split())
    print(solve(n, p))


if __name__ == "__main__":
    main()