n, p = map(int, input().split()) mod = 998244353 ans = 1 while n: ans *= pow(p, n//p, mod) ans %= mod n //= p print(ans)