mod = 998244353


def solve(n, p):
    """Return (n! - S) mod 998244353, where S = sum_{k=0}^{n//p} n! / (k! (n-kp)! p^k).

    For p >= 2, the k-th term of S equals the number of permutations of n
    elements that consist of exactly k p-cycles with every other element fixed.
    So S counts the permutations whose cycles all have length 1 or p.
    NOTE(review): presumably p is prime, in which case the result is the number
    of permutations sigma with sigma**p != identity -- confirm against the task.

    Assumes 0 <= n < mod (the factorial/inverse tables are only valid below the
    modulus) and p >= 1 (p == 0 raises ZeroDivisionError, as before).
    """
    fact = [1] * (n + 1)
    inv = [1] * (n + 1)
    finv = [1] * (n + 1)
    for i in range(2, n + 1):
        fact[i] = fact[i - 1] * i % mod
        # Linear-time modular inverse: inv[i] = -(mod // i) * inv[mod % i] (mod).
        inv[i] = mod - inv[mod % i] * (mod // i) % mod
        finv[i] = finv[i - 1] * inv[i] % mod

    inv_p = pow(p, mod - 2, mod)  # p^{-1} via Fermat's little theorem (mod is prime)
    total = 0
    inv_p_pow = 1  # inv_p ** k, advanced once per k instead of calling pow() each time
    for k in range(n // p + 1):
        term = fact[n] * finv[k] % mod
        term = term * finv[n - k * p] % mod * inv_p_pow % mod
        total = (total + term) % mod
        inv_p_pow = inv_p_pow * inv_p % mod
    return (fact[n] - total) % mod


def main():
    """Read "n p" from stdin and print solve(n, p)."""
    n, p = map(int, input().split())
    print(solve(n, p))


if __name__ == "__main__":
    main()