MOD = 998244353


def legendre_exponent(n: int, p: int) -> int:
    """Return sum_{i>=1} floor(n / p**i).

    For prime ``p`` this is Legendre's formula: the exponent of ``p`` in the
    prime factorisation of ``n!``.  The sum is accumulated by repeatedly
    floor-dividing ``n`` by ``p`` (floor(floor(n/p^k)/p) == floor(n/p^(k+1))),
    so it always runs until the quotient reaches zero instead of stopping at
    a fixed iteration cap.

    Args:
        n: Non-negative integer whose factorial is considered.
        p: Base, must be at least 2 (otherwise the quotient never shrinks).

    Raises:
        ValueError: If ``n`` is negative or ``p`` is less than 2.
    """
    if n < 0:
        raise ValueError("n must be non-negative")
    if p < 2:
        raise ValueError("p must be at least 2")
    total = 0
    quotient = n
    while quotient:
        quotient //= p
        total += quotient
    return total


def solve(n: int, p: int, mod: int = MOD) -> int:
    """Return ``p ** legendre_exponent(n, p) % mod``.

    ``p == 1`` is special-cased to ``1 % mod`` (any power of 1 is 1), which
    matches the original script's output for that input.

    Raises:
        ValueError: If ``n`` is negative or ``p`` is less than 1.
    """
    if p == 1:
        return 1 % mod
    return pow(p, legendre_exponent(n, p), mod)


def main() -> None:
    """Read ``n p`` from one line of stdin and print the answer."""
    n, p = map(int, input().split())
    print(solve(n, p))


if __name__ == "__main__":
    main()