# fpfpf
"""Print v_p(n!) * (n!) ** (n!) modulo 998244353 for ``n p`` read from stdin."""

MOD = 998244353


def legendre_exponent(n: int, p: int) -> int:
    """Return ``sum(n // p**i for i >= 1)``.

    For prime ``p`` this is Legendre's formula: the exponent of ``p`` in ``n!``.
    The series is summed until ``p**i`` exceeds ``n``, with no fixed cap on
    the number of terms.

    Raises:
        ValueError: if ``p < 2``, for which the series would never terminate.
    """
    if p < 2:
        raise ValueError(f"p must be at least 2, got {p}")
    total = 0
    power = p
    while power <= n:
        total += n // power
        power *= p
    return total


def solve(n: int, p: int, mod: int = MOD) -> int:
    """Return ``legendre_exponent(n, p) * (n!) ** (n!)`` modulo ``mod``.

    ``mod`` must be prime: the exponent ``n!`` is reduced modulo ``mod - 1``
    via Fermat's little theorem, which is only valid when the base ``n!`` is
    not divisible by ``mod``; that case is handled separately below.
    """
    m = legendre_exponent(n, p)
    if m == 0 or n >= mod:
        # n >= mod means mod divides n!, so (n!) ** (n!) is 0 modulo mod.
        # Returning early also avoids pow(0, 0, mod) == 1 (both residues of
        # n! would be 0) and an O(n) loop over a huge n.
        return 0
    base = 1      # n! mod `mod`
    exponent = 1  # n! mod (mod - 1), the Fermat-reduced exponent
    for i in range(2, n + 1):
        base = base * i % mod
        exponent = exponent * i % (mod - 1)
    return m % mod * pow(base, exponent, mod) % mod


def main() -> None:
    """Read ``n p`` from one line of stdin and print ``solve(n, p)``."""
    n, p = map(int, input().split())
    print(solve(n, p))


if __name__ == "__main__":
    main()