MOD = 998244353


def prime_power_in_factorial(n: int, p: int, mod: int = MOD) -> int:
    """Return ``p ** e % mod`` where ``e = sum(n // p**k for k >= 1 while p**k <= n)``.

    When ``p`` is prime, ``e`` is the exponent of ``p`` in ``n!`` (Legendre's
    formula), so the result is the largest power of ``p`` dividing ``n!``,
    reduced modulo ``mod``. For composite ``p`` the same sum is computed, but
    it does not in general equal the exponent of ``p`` in ``n!``.

    Args:
        n: Upper bound of the factorial. Values below ``p`` (including
            negatives) give ``e == 0`` and therefore a result of ``1``.
        p: Base; must be at least 2.
        mod: Modulus for the final power (defaults to 998244353).

    Returns:
        ``pow(p, e, mod)``.

    Raises:
        ValueError: If ``p < 2``. For ``p`` of 0 or 1 the powers ``p**k`` never
            exceed ``n``, so the sum does not terminate.
    """
    if p < 2:
        raise ValueError("p must be at least 2")

    exponent = 0
    quotient = n
    # Repeated floor division yields n // p, n // p**2, ... without building
    # ever-larger powers of p: floor(floor(n / p**k) / p) == floor(n / p**(k+1)).
    # The condition quotient >= p is equivalent to p**(k+1) <= n.
    while quotient >= p:
        quotient //= p
        exponent += quotient
    return pow(p, exponent, mod)


def main() -> None:
    """Read ``N P`` from one line of stdin and print the result modulo MOD."""
    n, p = map(int, input().split())
    print(prime_power_in_factorial(n, p))


if __name__ == "__main__":
    main()