"""Print p raised to the multiplicity of p in n!, reduced modulo 998244353.

Input (stdin): one line with two integers ``n`` and ``p``.
Output (stdout): ``p ** e mod 998244353`` where
``e = n // p + n // p**2 + n // p**3 + ...``.
"""

MOD = 998244353


def legendre_exponent(n: int, p: int) -> int:
    """Return ``sum(n // p**k for k >= 1)``.

    By Legendre's formula this sum is the exponent of ``p`` in the prime
    factorisation of ``n!`` when ``p`` is prime.

    Args:
        n: Upper bound of the factorial; values below ``p`` yield 0.
        p: Base whose powers are counted; must be at least 2.

    Returns:
        The accumulated sum of floor divisions.

    Raises:
        ValueError: If ``p < 2``. With ``p == 1`` the power never grows and
            the loop would not terminate; ``p == 0`` would divide by zero.
    """
    if p < 2:
        raise ValueError(f"p must be at least 2, got {p}")

    exponent = 0
    power = p
    # Each pass adds the count of multiples of p**k in 1..n.
    while power <= n:
        exponent += n // power
        power *= p
    return exponent


def solve(n: int, p: int) -> int:
    """Return ``p ** legendre_exponent(n, p)`` reduced modulo ``MOD``."""
    # Three-argument pow does modular exponentiation without building the
    # full power, which matters because the exponent grows with n.
    return pow(p, legendre_exponent(n, p), MOD)


def main() -> None:
    """Read ``n`` and ``p`` from stdin and print the answer."""
    n, p = map(int, input().split())
    print(solve(n, p))


if __name__ == "__main__":
    main()