"""Read two integers ``N P`` from stdin and print ``P ** e mod 998244353``.

``e`` is the sum of ``N // P**k`` over every ``k >= 1`` with ``P**k <= N``.
When ``P`` is prime this sum is the exponent of ``P`` in ``N!`` (Legendre's
formula), so the printed value is the largest power of ``P`` dividing ``N!``,
reduced modulo ``MOD``.
"""

MOD = 998244353


def legendre_exponent(n: int, p: int) -> int:
    """Return the sum of ``n // p**k`` for all ``k >= 1`` with ``p**k <= n``.

    For a prime ``p`` this is the exponent of ``p`` in ``n!``.

    Raises:
        ValueError: if ``p < 2``. With ``p == 1`` the running power never
            grows, so the loop below would not terminate, and ``p == 0``
            would divide by zero.
    """
    if p < 2:
        raise ValueError(f"p must be at least 2, got {p}")

    exponent = 0
    power = p
    # Each pass counts the multiples of p**k that are <= n; the loop ends
    # once the power exceeds n, i.e. after about log_p(n) iterations.
    while power <= n:
        exponent += n // power
        power *= p
    return exponent


def solve(n: int, p: int, mod: int = MOD) -> int:
    """Return ``p ** legendre_exponent(n, p)`` reduced modulo ``mod``."""
    # Three-argument pow performs modular exponentiation, so the possibly
    # huge exponent is never expanded into a full-size power.
    return pow(p, legendre_exponent(n, p), mod)


def main() -> None:
    """Read ``N P`` from one line of stdin and print the answer."""
    n, p = map(int, input().split())
    print(solve(n, p))


if __name__ == "__main__":
    main()