def lagrange(n, p):
	"""Return the sum of ``n // p**k`` over all ``k >= 1``.

	By Legendre's formula, when ``p`` is prime this sum is the exponent of
	``p`` in the prime factorisation of ``n!``.

	Args:
		n: Integer whose repeated quotients are summed. Values ``<= 0``
			yield ``0`` regardless of ``p``.
		p: Integer divisor; must be at least 2 whenever ``n > 0``.

	Returns:
		The accumulated sum as an ``int`` (``0`` when ``n < p``).

	Raises:
		ValueError: If ``n > 0`` and ``p < 2``.
	"""
	if n <= 0:
		return 0
	if p < 2:
		# p == 1 would never shrink n (endless loop), p == 0 would divide by
		# zero, and a negative p would produce a meaningless negative sum.
		raise ValueError(f"p must be at least 2, got {p!r}")

	ret = 0
	while n > 0:
		# After the k-th division n holds floor(original_n / p**k).
		n //= p
		ret += n
	return ret

# Read two whitespace-separated integers from a single input line, then print
# p raised to lagrange(n, p), reduced modulo 998244353.
n, p = (int(token) for token in input().split())
exponent = lagrange(n, p)
print(pow(p, exponent, 998244353))