import sys
from heapq import *  # NOTE(review): nothing visible here uses heapq; kept so no import is removed

# Bind the faster line reader; this intentionally shadows the builtin input().
input = sys.stdin.readline

MOD = 998244353


def linput(ty=int, cvt=list):
    """Read one stdin line, convert each whitespace-separated token with *ty*,
    and collect the converted tokens with *cvt* (a list by default)."""
    return cvt(map(ty, input().split()))


def yesno(b):
    """Print "Yes" when *b* is True/1 and "No" when *b* is False/0.

    *b* is used directly as an index into ["No", "Yes"].
    """
    print("No Yes".split()[b])


def gcd(n, m):
    """Return the greatest common divisor of *n* and *m* (Euclid's algorithm).

    For non-negative inputs this agrees with math.gcd; with negative inputs the
    sign of the result follows Python's % semantics and is not normalized.
    """
    while m:
        n, m = m, n % m
    return n


def lcm(n, m):
    """Return the least common multiple of *n* and *m*; lcm(0, 0) is 0."""
    g = gcd(n, m)
    if g == 0:
        # gcd is 0 only when n == m == 0; avoid dividing by zero.
        return 0
    return n * m // g


def prime_factorize(n):
    """Return the prime factors of *n* in non-decreasing order, with multiplicity.

    prime_factorize(1) == []. Intended for positive *n*.

    Raises:
        ValueError: if n == 0, which has no factorization (the trial-division
            loop would otherwise append 2 forever).
    """
    if n == 0:
        raise ValueError("cannot factorize 0")
    factors = []
    append = factors.append
    # Strip factors of 2 first so the main loop only needs odd candidates.
    while n % 2 == 0:
        append(2)
        n //= 2
    f = 3
    while f * f <= n:
        if n % f == 0:
            append(f)
            n //= f
        else:
            f += 2
    # Whatever remains above 1 is a prime factor larger than sqrt(original n).
    if n != 1:
        append(n)
    return factors


def legendre_power_mod(n, p, mod=MOD):
    """Return p**t % mod, where t = sum over k >= 1 of floor(n / p**k).

    When *p* is prime, t is the exponent of p in n! (Legendre's formula), so the
    result is the largest power of p dividing n!, reduced modulo *mod*.
    Non-positive *n* gives t == 0.
    """
    if p == 1:
        # n // 1 never shrinks, so the summation loop below would never end;
        # 1**t is 1 for every t anyway.
        return 1 % mod
    x = n
    t = 0
    while x > 0:
        x //= p
        t += x
    return pow(p, t, mod)


def main():
    """Read N and P from one stdin line and print legendre_power_mod(N, P)."""
    N, P = linput()
    print(legendre_power_mod(N, P, MOD))


if __name__ == "__main__":
    main()