MOD = 998244353


def factorial_prime_exponent(n, p):
    """Return the sum of ``n // p**k`` over all ``k >= 1`` with ``p**k <= n``.

    For a prime ``p`` this is Legendre's formula: the exponent of ``p`` in
    the prime factorisation of ``n!``.

    Args:
        n: Upper bound of the factorial; any ``n < p`` yields 0.
        p: Base whose powers are summed over; must be at least 2.

    Returns:
        The accumulated exponent as a non-negative int.

    Raises:
        ValueError: If ``p < 2``, because the powers of ``p`` would never
            exceed ``n`` and the sum would not terminate.
    """
    if p < 2:
        raise ValueError(f"p must be at least 2, got {p}")

    exponent = 0
    power = p
    # Loop on the real termination condition rather than a fixed iteration
    # cap, so arbitrarily large n is handled without dropping terms.
    while power <= n:
        exponent += n // power
        power *= p
    return exponent


def main():
    """Read ``N P`` from one line of stdin and print ``P**e mod MOD``.

    ``e`` is the value of ``factorial_prime_exponent(N, P)``.
    """
    n, p = map(int, input().split())
    # Three-argument pow performs modular exponentiation, so the possibly
    # huge exponent never materialises as a full power.
    print(pow(p, factorial_prime_exponent(n, p), MOD))


if __name__ == "__main__":
    main()