## https://yukicoder.me/problems/no/2380
MOD = 998244353


def legendre_exponent(n, p):
    """Return sum over k >= 1 of floor(n / p**k).

    For prime p this is the exponent of p in n! (Legendre's formula).
    Each term is a single floor division: the number of multiples of
    p**k that are <= n is exactly n // p**k. The whole sum therefore costs
    O(log_p n) steps, instead of walking every multiple individually.

    Raises:
        ValueError: if p < 2. The powers of p would never exceed n, so
            the sum would not terminate.
    """
    if p < 2:
        raise ValueError("p must be at least 2")
    total = 0
    power = p
    while power <= n:
        total += n // power  # count of multiples of p**k in [1, n]
        power *= p
    return total


def solve(n, p):
    """Return p ** legendre_exponent(n, p) reduced modulo MOD."""
    # Three-argument pow does modular exponentiation, so a huge exponent is fine.
    return pow(p, legendre_exponent(n, p), MOD)


def main():
    """Read "N P" from stdin and print the answer for that pair."""
    N, P = map(int, input().split())
    print(solve(N, P))


if __name__ == '__main__':
    main()