# coding: utf-8
"""Read ``N P`` from stdin and print ``P ** e mod 998244353``.

Here ``e = sum(N // P**k for k >= 1)`` (Legendre's formula); when P is
prime, ``e`` is the exponent of P in the prime factorization of N!.
"""

MOD = 998244353


def legendre_exponent(n, p):
    """Return ``sum(n // p**k for k = 1, 2, ...)``.

    When ``p`` is prime this is the exponent of ``p`` in ``n!``.

    Args:
        n: non-negative integer.
        p: base of the powers; must be at least 2.

    Returns:
        The (finite) sum of ``n // p**k`` over all ``k >= 1``.

    Raises:
        ValueError: if ``p < 2`` (for p == 1 the powers never grow and the
            sum would be unbounded).
    """
    if p < 2:
        raise ValueError("p must be at least 2")
    total = 0
    power = p
    # Any power larger than n contributes n // power == 0, so bound the loop
    # by n itself instead of a fixed cap (a cap like 10**12 silently drops
    # terms whenever n exceeds it).
    while power <= n:
        total += n // power
        power *= p
    return total


def solve(n, p, mod=MOD):
    """Return ``p ** legendre_exponent(n, p)`` reduced modulo ``mod``."""
    return pow(p, legendre_exponent(n, p), mod)


def main():
    """Read ``N P`` from one line of stdin and print the answer."""
    n, p = map(int, input().split())
    print(solve(n, p))


if __name__ == "__main__":
    main()