MOD = 998244353


def legendre_exponent(n: int, p: int) -> int:
    """Return the sum of floor(n / p**k) over all k >= 1.

    For prime ``p`` this is Legendre's formula, i.e. the exponent of ``p``
    in the prime factorisation of ``n!``.  Non-positive ``n`` yields 0.

    Raises:
        ValueError: if ``p < 2``.  With ``p == 1`` the loop would never
            terminate (``n // 1 == n``), and ``p == 0`` would divide by zero.
    """
    if p < 2:
        raise ValueError(f"p must be >= 2, got {p}")
    exponent = 0
    while n >= p:
        # floor(n / p**k) == floor(floor(n / p**(k-1)) / p), so repeatedly
        # dividing n by p yields each successive term of the sum.
        n //= p
        exponent += n
    return exponent


def solve(n: int, p: int, mod: int = MOD) -> int:
    """Return ``p ** legendre_exponent(n, p)`` reduced modulo ``mod``.

    NOTE(review): presumably the task guarantees ``p`` is prime, in which
    case this is the largest power of ``p`` dividing ``n!`` -- confirm.
    """
    # Three-argument pow does modular exponentiation without materialising
    # the (potentially huge) full power.
    return pow(p, legendre_exponent(n, p), mod)


def main() -> None:
    """Read ``N P`` from one line of stdin and print the answer."""
    n, p = map(int, input().split())
    print(solve(n, p))


if __name__ == "__main__":
    main()