"""Print P ** e modulo 998244353, where e = sum of floor(N / P**k) for k >= 1.

Input: a single line on standard input holding two whitespace-separated
integers ``N P``.  Output: one integer on standard output.
"""

# Modulus applied to the printed power.
MOD = 998244353


def _floor_quotient_sum(n: int, p: int) -> int:
    """Return floor(n/p) + floor(n/p**2) + ... until the terms reach zero.

    By Legendre's formula this equals the exponent of ``p`` in ``n!`` when
    ``p`` is prime.  A non-positive ``n`` yields 0 because the loop body
    never runs.  The caller must ensure ``p >= 2``; otherwise the loop does
    not make progress.
    """
    total = 0
    while n > 0:
        # After k iterations n holds floor(N / p**k), so dividing first and
        # adding afterwards accumulates exactly the terms of the sum.
        n //= p
        total += n
    return total


def solve(n: int, p: int) -> int:
    """Return ``p ** e % MOD`` with ``e`` the floor-quotient sum of ``n`` by ``p``.

    Args:
        n: Upper bound of the factorial-style sum; values <= 0 give ``e == 0``.
        p: Base of the repeated division; must be at least 2.

    Returns:
        ``pow(p, e, MOD)`` as a non-negative integer below ``MOD``.

    Raises:
        ValueError: If ``p < 2``.  ``p == 1`` would loop forever and
            ``p == 0`` would divide by zero, so both are rejected up front.
    """
    if p < 2:
        raise ValueError(f"P must be at least 2, got {p}")
    # Three-argument pow keeps intermediate values reduced modulo MOD.
    return pow(p, _floor_quotient_sum(n, p), MOD)


def main() -> None:
    """Read ``N P`` from one line of stdin and print the result of ``solve``."""
    n, p = map(int, input().split())
    print(solve(n, p))


if __name__ == "__main__":
    main()