MOD = 998244353


def solve(n, p, mod=MOD):
    """Return ``p ** cnt`` modulo *mod* for the exponent ``cnt`` described below.

    ``cnt`` is the sum, over every multiple ``k`` of *p* with ``1 <= k <= n``,
    of the exponent ``e`` satisfying ``p**e <= k < p**(e + 1)``.  The numbers
    are processed block by block: block ``e`` is ``[p**e, p**(e+1) - 1]``
    clipped to ``n``, and each multiple of *p* inside it contributes ``e``.

    Args:
        n: Upper bound of the range of integers considered.
        p: Base; must be a positive integer.
        mod: Modulus for the final power (defaults to 998244353).

    Returns:
        ``pow(p, cnt, mod)``.

    Raises:
        ValueError: if ``p < 1`` (no meaningful base; the block walk below
            would divide by zero or never terminate).
    """
    if p < 1:
        raise ValueError("p must be a positive integer")
    if p == 1:
        # 1 ** cnt == 1 for any cnt; the block walk would never advance
        # (1 ** e stays 1), so short-circuit instead of looping ~n times.
        return 1 % mod

    cnt = 0
    lo = p  # p ** e: first integer whose base-p exponent is e
    e = 1
    while lo <= n:
        hi = min(lo * p - 1, n)
        # Multiples of p in [lo, hi] == hi // p - (lo - 1) // p.
        cnt += (hi // p - (lo - 1) // p) * e
        lo *= p  # advance incrementally instead of recomputing p ** e
        e += 1
    return pow(p, cnt, mod)


def main():
    """Read ``N P`` from one stdin line and print ``solve(N, P)``."""
    n, p = map(int, input().split())
    print(solve(n, p))


if __name__ == "__main__":
    main()