MOD = 998244353


def prime_power_in_factorial(n, p, mod=MOD):
    """Return ``p**e % mod`` where ``e = sum(n // p**k for k >= 1 while p**k <= n)``.

    When ``p`` is prime, ``e`` is the exponent of ``p`` in ``n!`` (Legendre's
    formula), so the result is the largest power of ``p`` dividing ``n!``,
    reduced modulo ``mod``.

    Args:
        n: Upper bound of the factorial. For n < p (including n <= 0) the
            exponent is 0 and the result is ``1 % mod``.
        p: Base; presumably a prime in the intended use — TODO confirm
            against the problem statement. Must be >= 1.
        mod: Modulus for the final power (defaults to 998244353).

    Returns:
        ``pow(p, e, mod)``.

    Raises:
        ValueError: if ``p < 1``. For such p, successive powers never exceed
            ``n``, so the sum is ill-defined (``p == 0`` would divide by zero).
    """
    if p < 1:
        raise ValueError("p must be a positive integer")
    if p == 1:
        # 1**k <= n for every k, so the general loop would take n iterations.
        # The answer is 1 regardless of the exponent.
        return 1 % mod

    exponent = 0
    power = p  # running value of p**k, grown multiplicatively
    while power <= n:
        exponent += n // power
        power *= p
    return pow(p, exponent, mod)


def main():
    """Read ``N P`` from stdin and print ``P**v_P(N!) mod 998244353``."""
    n, p = map(int, input().split())
    print(prime_power_in_factorial(n, p))


if __name__ == "__main__":
    main()