"""Print P**t mod 998244353, where t = sum_{k>=1} floor(N / P**k).

For prime P, t is the exponent of P in N! (Legendre's formula).
Reads two integers N and P from one line of standard input.
"""

mod = 998244353


def legendre_exponent(n, p):
    """Return sum_{k>=1} floor(n / p**k).

    For prime p this equals the exponent of p in the prime factorization
    of n! (Legendre's formula).

    Raises:
        ValueError: if p < 2. With p == 1 the powers of p never exceed n,
            so the loop would never terminate; with p == 0 it would divide
            by zero.
    """
    if p < 2:
        raise ValueError(f"p must be at least 2, got {p}")
    total = 0
    power = p
    while power <= n:
        total += n // power
        power *= p
    return total


def solve(n, p, modulus=mod):
    """Return p ** legendre_exponent(n, p) reduced modulo *modulus*."""
    # Three-argument pow does modular exponentiation without building the
    # full (potentially huge) power.
    return pow(p, legendre_exponent(n, p), modulus)


def main():
    """Read "N P" from stdin and print the answer."""
    n, p = map(int, input().split())
    print(solve(n, p))


if __name__ == "__main__":
    main()