MOD = 998244353


def legendre_exponent(n: int, p: int) -> int:
    """Return the exponent of p in n!, i.e. sum_{i>=1} floor(n / p**i).

    Uses floor(n / p**(i+1)) == floor(floor(n / p**i) / p), so the loop
    runs about log_p(n) times with no artificial iteration cap (the old
    fixed 99-term loop truncated the sum once n >= p**100).

    Args:
        n: Upper bound of the factorial; values <= 0 give exponent 0.
        p: Base whose exponent is counted; must be >= 2.

    Returns:
        The (non-negative) exponent.

    Raises:
        ValueError: If p < 2, since the sum would not terminate.
    """
    if p < 2:
        raise ValueError(f"p must be >= 2, got {p}")
    exponent = 0
    # `n > 0` (not `while n`) so negative n cannot spin forever at -1 // p == -1.
    while n > 0:
        n //= p
        exponent += n
    return exponent


def solve(n: int, p: int, mod: int = MOD) -> int:
    """Return p ** legendre_exponent(n, p) modulo `mod`.

    p == 1 is special-cased: 1 to any power is 1, which matches what the
    original script printed for that input.

    Raises:
        ValueError: If p < 1.
    """
    if p == 1:
        return 1 % mod
    return pow(p, legendre_exponent(n, p), mod)


def main() -> None:
    """Read "N P" from stdin and print the answer."""
    n, p = map(int, input().split())
    print(solve(n, p))


if __name__ == "__main__":
    main()