def resolve():
    """Read ``n p`` from stdin and print ``p ** e`` modulo 998244353.

    ``e`` is ``n // p + n // p**2 + n // p**3 + ...`` (terms while
    ``p**k <= n``). When ``p`` is prime this is Legendre's formula for the
    exponent of ``p`` in ``n!``, so the printed value is the largest power
    of ``p`` dividing ``n!``, reduced modulo 998244353.

    Raises:
        ValueError: if ``p < 2``; for such ``p`` the powers ``p**k`` need
            not grow past ``n`` and the summation loop could never end.
    """
    import sys

    input = sys.stdin.readline
    MOD = 998244353

    n, p = map(int, input().split())
    if p < 2:
        raise ValueError(f"p must be >= 2, got {p}")

    # Accumulate floor(n / p**k) for k = 1, 2, ... while p**k <= n.
    # The exponent is deliberately NOT reduced modulo MOD: exponents may only
    # be reduced modulo MOD - 1 (and only when p is coprime to MOD). It is
    # bounded by n, and pow() below handles arbitrarily large exponents.
    exponent = 0
    power = p
    while power <= n:
        exponent += n // power
        power *= p

    # pow(..., MOD) already yields a value in [0, MOD); no extra % needed.
    print(pow(p, exponent, MOD))


if __name__ == "__main__":
    resolve()