# Modulus applied to the printed result.
MOD = 998244353


def lagrange(n: int, p: int) -> int:
    """Return the sum of floor(n / p**k) over all k >= 1.

    When ``p`` is prime this sum is the exponent of ``p`` in the prime
    factorization of ``n!`` (Legendre's formula).

    Args:
        n: Upper bound of the factorial; values <= 0 yield 0.
        p: Base of the repeated division; must be at least 2.

    Returns:
        The accumulated sum of the successive integer quotients.

    Raises:
        ValueError: If ``p`` is less than 2. With ``p == 1`` the quotient
            never shrinks and the loop would not terminate; ``p == 0``
            would divide by zero; a negative ``p`` has no meaningful result.
    """
    if p < 2:
        raise ValueError(f"p must be at least 2, got {p}")

    exponent = 0
    while n > 0:
        # After the k-th division n holds floor(original_n / p**k).
        n //= p
        exponent += n
    return exponent


def main() -> None:
    """Read ``n`` and ``p`` from stdin and print p**lagrange(n, p) mod MOD."""
    n, p = map(int, input().split())
    print(pow(p, lagrange(n, p), MOD))


# Guarded so importing this module does not block on stdin.
if __name__ == "__main__":
    main()