"""Print P**e mod 998244353, where e = sum_{k>=1} floor(N / P**k).

For prime P, e is the exponent of P in the prime factorisation of N!
(Legendre's formula), so before reduction the printed value is the largest
power of P dividing N!. NOTE(review): P is presumably guaranteed prime by the
problem statement; the computation itself does not check this.

Input: two integers N and P on standard input (separated by any whitespace).
"""

import sys

MOD = 998244353


def legendre_exponent(n: int, p: int) -> int:
    """Return sum_{k>=1} floor(n / p**k).

    For prime ``p`` this equals the exponent of ``p`` in ``n!``.
    Non-positive ``n`` yields 0.

    Raises:
        ValueError: if ``p < 2``. The repeated division would never
            terminate for ``p == 1`` and would divide by zero for ``p == 0``.
    """
    if p < 2:
        raise ValueError(f"base must be at least 2, got {p}")
    exponent = 0
    # After the k-th division, n holds floor(original_n / p**k): the number
    # of multiples of p**k in 1..original_n, which is the k-th Legendre term.
    while n > 0:
        n //= p
        exponent += n
    return exponent


def solve(n: int, p: int, mod: int = MOD) -> int:
    """Return ``p ** legendre_exponent(n, p)`` reduced modulo ``mod``.

    ``p == 1`` is answered directly as ``1 % mod``. Every power of 1 is 1,
    so the unbounded exponent is irrelevant. This also avoids the
    non-terminating division loop.
    """
    if p == 1:
        return 1 % mod
    return pow(p, legendre_exponent(n, p), mod)


def main() -> None:
    """Read N and P from stdin and print the answer."""
    # Take the first two tokens so N and P may be on one line or on two.
    n, p = map(int, sys.stdin.read().split()[:2])
    print(solve(n, p))


if __name__ == "__main__":
    main()