MOD = 998244353


def legendre_exponent(n: int, p: int) -> int:
    """Return sum_{k>=1} floor(n / p**k).

    When ``p`` is prime this is Legendre's formula: the exponent of ``p`` in
    the prime factorisation of ``n!``.  For composite ``p`` the same sum is
    still computed; it just no longer has that interpretation.

    Args:
        n: Non-negative upper bound of the factorial.
        p: Base of the powers; must be at least 2.

    Returns:
        The sum of ``n // p**k`` over all k >= 1 with ``p**k <= n``.

    Raises:
        ValueError: if ``p < 2``.  With ``p == 1`` the powers never grow, so
            the loop would never terminate.  With ``p == 0`` it would divide
            by zero.
    """
    if p < 2:
        raise ValueError(f"p must be >= 2, got {p}")
    exponent = 0
    power = p
    # Each multiple of p**k contributes one more factor of p, so we add the
    # count of such multiples for every power not exceeding n.
    while power <= n:
        exponent += n // power
        power *= p
    return exponent


def solve(n: int, p: int, mod: int = MOD) -> int:
    """Return ``p ** legendre_exponent(n, p)`` reduced modulo ``mod``.

    Three-argument ``pow`` keeps the computation fast even when the exponent
    is huge, since it is on the order of n / (p - 1).
    """
    return pow(p, legendre_exponent(n, p), mod)


def main() -> None:
    """Read ``N P`` from stdin and print the answer for them."""
    n, p = map(int, input().split())
    print(solve(n, p))


if __name__ == "__main__":
    main()