"""Print P raised to sum(N // P**k for k >= 1), reduced modulo 998244353.

Input (stdin): one line containing two integers ``N P``.
Output (stdout): ``P ** e mod 998244353`` where ``e = sum(N // P**k)``.

When P is prime, ``e`` is the exponent of P in N! (Legendre's formula), so the
printed value is the largest power of P dividing N!, taken modulo MOD.
NOTE(review): primality of P is presumably guaranteed by the input; it is not
validated here.
"""

MOD = 998244353


def legendre_exponent(n: int, p: int) -> int:
    """Return ``sum(n // p**k for k = 1, 2, ...)``.

    Args:
        n: The upper bound; any ``n < p`` (including ``n <= 0``) yields 0.
        p: The base of the powers; must be at least 2 so the powers grow and
            the loop terminates.

    Returns:
        The sum of ``n // p**k`` over every ``k >= 1`` with ``p**k <= n``.

    Raises:
        ValueError: If ``p < 2``.
    """
    if p < 2:
        raise ValueError(f"p must be at least 2, got {p}")

    exponent = 0
    power = p
    # Loop until the power exceeds n instead of using a fixed iteration cap:
    # Python ints are unbounded, so a cap would silently truncate the sum for
    # very large n.
    while power <= n:
        exponent += n // power
        power *= p
    return exponent


def factorial_power_mod(n: int, p: int, mod: int = MOD) -> int:
    """Return ``p ** legendre_exponent(n, p)`` reduced modulo ``mod``.

    Args:
        n: The upper bound passed to :func:`legendre_exponent`.
        p: The base; must be at least 1.
        mod: The modulus; defaults to ``MOD``.

    Returns:
        ``pow(p, e, mod)`` with ``e = legendre_exponent(n, p)``; for ``p == 1``
        the result is ``1 % mod``.

    Raises:
        ValueError: If ``p < 1``.
    """
    if p == 1:
        # Every power of 1 is 1; handle it here because the powers of 1 never
        # exceed n, so the summation loop would not terminate.
        return 1 % mod
    return pow(p, legendre_exponent(n, p), mod)


def main() -> None:
    """Read ``N P`` from stdin and print the answer."""
    n, p = map(int, input().split())
    print(factorial_power_mod(n, p))


if __name__ == "__main__":
    main()