def prime_factorize(n: int) -> list:
    """Return the prime factors of *n* in non-decreasing order, with multiplicity.

    Uses trial division: factors of 2 are stripped first, then odd
    candidates are tried while ``f * f <= n``. Whatever remains above 1
    after that is itself prime and is appended last.

    Args:
        n: The integer to factorize; expected to be a positive integer.

    Returns:
        A list of prime factors, e.g. ``[2, 2, 3]`` for 12 and ``[]`` for 1.

    Raises:
        ValueError: If *n* is 0. Zero is divisible by 2 forever, so the
            halving loop would otherwise never terminate.
    """
    if n == 0:
        raise ValueError("prime_factorize() is undefined for 0")

    factors = []
    while n % 2 == 0:
        factors.append(2)
        n //= 2

    f = 3
    while f * f <= n:
        if n % f == 0:
            # Stay on the same candidate so repeated factors are all collected.
            factors.append(f)
            n //= f
        else:
            f += 2  # even candidates were already eliminated above

    if n != 1:
        factors.append(n)
    return factors


MOD = 998244353


def solve(n: int, p: int) -> int:
    """Return ``p ** v % MOD`` where ``v = sum(n // p**i)`` for ``i = 1, 2, ...``.

    The sum stops at the first power of *p* that exceeds *n*, and never uses
    more than ``n - 1`` terms. For a prime *p* this sum is Legendre's formula,
    i.e. the exponent of *p* in ``n!``.

    Args:
        n: Upper bound used in the floor divisions.
        p: Base whose successive powers divide *n*.

    Returns:
        ``pow(p, v, MOD)`` for the exponent ``v`` described above.
    """
    v = 0
    q = 1
    for _ in range(1, n):
        # Running product instead of recomputing pow(p, i) from scratch.
        q *= p
        if q > n:
            break
        v += n // q
    return pow(p, v, MOD)


def main() -> None:
    """Read ``n`` and ``p`` from one line of stdin and print the answer."""
    n, p = map(int, input().split())
    print(solve(n, p))


if __name__ == "__main__":
    main()