MOD = 998244353


def legendre_exponent(n: int, p: int) -> int:
    """Return sum_{i>=1} floor(n / p**i) (Legendre's formula).

    When p is prime this is the exponent of p in the prime factorisation of
    n!.  For composite p the sum is NOT that exponent; the script presumably
    receives a prime P — TODO(review): confirm against the problem statement.

    Args:
        n: Upper end of the factorial; values below p yield 0.
        p: Base whose powers are summed over; must be at least 2.

    Returns:
        The (non-negative) sum of n // p**i over all i with p**i <= n.

    Raises:
        ValueError: if p < 2 (p == 1 would loop forever, p == 0 would divide
            by zero).
    """
    if p < 2:
        raise ValueError(f"p must be at least 2, got {p}")
    exponent = 0
    power = p
    # Walk p, p**2, p**3, ... (multiply by p each step).  Every multiple of
    # p**i contributes one additional factor of p, so no power may be skipped;
    # squaring the power instead would miss p**3, p**5, p**6, ...
    while power <= n:
        exponent += n // power
        power *= p
    return exponent


def main() -> None:
    """Read "N P" from stdin and print P**legendre_exponent(N, P) modulo MOD."""
    n, p = map(int, input().split())
    print(pow(p, legendre_exponent(n, p), MOD))


if __name__ == "__main__":
    main()