# https://yukicoder.me/problems/no/2380
MOD = 998244353


def legendre_exponent(n, p):
    """Return sum(n // p**k for k >= 1 while p**k <= n).

    When p is prime this is the exponent of p in n! (Legendre's formula).

    Raises:
        ValueError: if p < 2. With p == 1 the power never grows, so the loop
            would never terminate; with p == 0 it would divide by zero.
    """
    if p < 2:
        raise ValueError(f"p must be at least 2, got {p}")
    exponent = 0
    power = p
    # Each multiple of p**k contributes one more factor of p.
    while power <= n:
        exponent += n // power
        power *= p
    return exponent


def solve(n, p, mod=MOD):
    """Return p ** legendre_exponent(n, p) reduced modulo mod."""
    # Three-argument pow does modular exponentiation without building the huge power.
    return pow(p, legendre_exponent(n, p), mod)


def main():
    """Read "N P" from stdin and print the answer modulo MOD."""
    n, p = map(int, input().split())
    print(solve(n, p))


if __name__ == '__main__':
    main()