"""Print p ** (sum_{k>=1} floor(n / p**k)) modulo 998244353.

Reads two integers ``n`` and ``p`` from a single line of standard input.
For prime ``p`` the exponent is Legendre's formula, so the printed value is
the largest power of ``p`` dividing ``n!``, reduced modulo 998244353.
"""

MOD = 998244353


def legendre_exponent(n: int, p: int) -> int:
    """Return sum_{k>=1} floor(n / p**k).

    For prime ``p`` this is the exponent of ``p`` in the prime factorisation
    of ``n!``.  NOTE(review): for composite ``p`` the sum is NOT that
    exponent -- presumably the input guarantees ``p`` is prime; confirm.

    ``p`` must not be 1 (the loop would never terminate); callers handle
    that case separately.
    """
    total = 0
    while n > 0:
        n //= p          # n becomes floor(original_n / p**k) for k = 1, 2, ...
        total += n
    return total


def solve(n: int, p: int, mod: int = MOD) -> int:
    """Return ``p ** legendre_exponent(n, p) % mod``.

    ``p == 1`` is special-cased: the exponent loop would never terminate
    (``n // 1 == n``), but 1 raised to any power is 1, so the answer is
    ``1 % mod``.
    """
    if p == 1:
        return 1 % mod
    return pow(p, legendre_exponent(n, p), mod)


def main() -> None:
    """Read ``n p`` from stdin and print the answer."""
    n, p = map(int, input().split())
    print(solve(n, p))


if __name__ == "__main__":
    main()