"""Read N and P, then print P**e mod 998244353, where e = sum_{k>=1} floor(N / P**k).

For prime P, e is the exponent of P in N! (Legendre's formula), so the output
is the largest power of P dividing N!, reduced modulo 998244353.
"""

MOD = 998244353


def legendre_exponent(n: int, p: int) -> int:
    """Return sum_{k>=1} floor(n / p**k).

    For prime ``p`` this is the exponent of ``p`` in ``n!``. A non-positive
    ``n`` yields 0.

    Raises:
        ValueError: if ``p < 2``. The repeated division by ``p`` would never
            reach 0 (p == 1, p < 0) or would divide by zero (p == 0).
    """
    if p < 2:
        raise ValueError(f"base must be >= 2, got {p}")
    exponent = 0
    while n > 0:
        n //= p  # n now holds floor(original_n / p**k) for the next k
        exponent += n
    return exponent


def solve(n: int, p: int) -> int:
    """Return p**e mod MOD, with e = legendre_exponent(n, p).

    ``p == 1`` is answered directly. Every power of 1 is 1, and computing e
    would never terminate.
    """
    if p == 1:
        return 1 % MOD
    return pow(p, legendre_exponent(n, p), MOD)


def main() -> None:
    """Read "N P" from one line of stdin and print the answer."""
    n, p = map(int, input().split())
    print(solve(n, p))


if __name__ == "__main__":
    main()