MOD = 998244353


def exponent(n: int, p: int) -> int:
    """Return sum over j >= 0 of 2**j * (n // p**(2**j)).

    The divisor is squared on every step (p, p**2, p**4, p**8, ...), and the
    weight doubles alongside it (1, 2, 4, 8, ...). The loop stops once the
    divisor exceeds ``n``, so the result is 0 when ``n < p``.

    NOTE(review): this is *not* Legendre's formula sum_k n // p**k (the
    exponent of p in n!). That formula would advance the divisor by ``*= p``
    with a constant weight of 1. The squaring/doubling scheme is kept exactly
    as the original script had it; confirm it against the problem statement.

    Args:
        n: Upper bound being divided.
        p: Base of the divisor sequence; must be at least 2.

    Returns:
        The weighted sum described above.

    Raises:
        ValueError: If ``p < 2``. With p == 1 the divisor never grows, so the
            loop would never terminate. With p == 0 it would divide by zero.
    """
    if p < 2:
        raise ValueError(f"p must be >= 2, got {p}")
    total = 0
    divisor = p  # p ** (2 ** j)
    weight = 1   # 2 ** j, doubled instead of recomputing pow(2, j)
    while n >= divisor:
        total += n // divisor * weight
        divisor *= divisor
        weight <<= 1
    return total


def solve(n: int, p: int, mod: int = MOD) -> int:
    """Return ``p ** exponent(n, p)`` reduced modulo ``mod``."""
    # Three-argument pow keeps the huge exponent cheap (modular exponentiation).
    return pow(p, exponent(n, p), mod)


def main() -> None:
    """Read ``N P`` from one line of stdin and print the answer."""
    n, p = map(int, input().split())
    print(solve(n, p))


if __name__ == "__main__":
    main()