"""Read ``n p`` from stdin and print ``p**e mod 998244353``.

Here ``e = sum(n // p**k for k = 1, 2, ... while p**k <= n)``. When ``p`` is
prime, ``e`` is the exponent of ``p`` in ``n!`` (Legendre's formula), so the
printed value is the largest power of ``p`` dividing ``n!``, reduced modulo
998244353. Primality of ``p`` is not checked.
"""

MOD = 998244353


def factorial_prime_power_mod(n: int, p: int, mod: int = MOD) -> int:
    """Return ``p ** e % mod`` where ``e = sum(n // p**k)`` over ``p**k <= n``.

    Args:
        n: Upper bound of the factorial. If ``n < p`` (including 0 and
            negative values) the exponent is 0 and the result is ``1 % mod``.
        p: Base of the power; must be at least 2. When ``p`` is prime, the
            exponent equals the multiplicity of ``p`` in ``n!``.
        mod: Modulus applied to the result (defaults to 998244353).

    Returns:
        ``pow(p, e, mod)``.

    Raises:
        ValueError: if ``p < 2``. With ``p == 1`` the power ``p**k`` never
            grows past ``n`` (infinite loop), and ``p == 0`` would divide by
            zero.
    """
    if p < 2:
        raise ValueError(f"p must be >= 2, got {p}")

    exponent = 0
    power = p
    while power <= n:
        exponent += n // power
        power *= p

    # One modular exponentiation of the summed exponent is equivalent to
    # multiplying pow(p, n // p**k, mod) term by term, but cheaper.
    return pow(p, exponent, mod)


def main() -> None:
    """Read two integers ``n p`` from one stdin line and print the result."""
    n, p = map(int, input().split())
    print(factorial_prime_power_mod(n, p))


if __name__ == "__main__":
    main()