"""Print p**e mod 998244353, where e = sum_{k>=1} floor(n / p**k).

For prime p, e is the exponent of p in n! (Legendre's formula), so the
printed value is the largest power of p dividing n!, reduced modulo
998244353. Input: one line containing two integers ``n p``.
"""

MOD = 998244353


def legendre_exponent(n: int, p: int) -> int:
    """Return sum_{k>=1} floor(n / p**k).

    For prime p this equals the exponent of p in the prime factorisation
    of n! (Legendre's formula).

    Raises:
        ValueError: if n < 0 or p < 2. Repeated floor division by p would
            never reach 0 for p == 1 or for negative n (``-1 // p == -1``),
            and would divide by zero for p == 0.
    """
    if n < 0:
        raise ValueError(f"n must be non-negative, got {n}")
    if p < 2:
        raise ValueError(f"p must be at least 2, got {p}")
    e = 0
    while n:
        # After the k-th division, n == floor(n_original / p**k).
        n //= p
        e += n
    return e


def main() -> None:
    """Read ``n p`` from stdin and print p**legendre_exponent(n, p) mod MOD."""
    n, p = map(int, input().split())
    print(pow(p, legendre_exponent(n, p), MOD))


if __name__ == "__main__":
    main()