#include <iostream>
#include <vector>

using ll = long long;

// Prime modulus; every value computed below is reduced modulo MOD.
constexpr ll MOD = 998244353;

/// Computes x^y mod m by iterative binary exponentiation.
/// @param x base (expected non-negative; it is reduced mod m first)
/// @param y exponent (expected >= 0; y == 0 yields 1)
/// @param m modulus (expected small enough that (m-1)^2 fits in long long)
/// @return x^y mod m
ll modpow(ll x, ll y, ll m) {
    ll result = 1;
    x %= m;
    while (y > 0) {
        if (y & 1) {
            result = result * x % m;
        }
        x = x * x % m;
        y >>= 1;
    }
    return result;
}

/// Returns ( n! - sum_{i=0}^{floor(n/p)} n! / (i! * p^i * (n - i*p)!) ) mod MOD.
///
/// For p >= 1, the term n! / (i! * p^i * (n - i*p)!) equals the number of ways
/// to choose i pairwise-disjoint cycles of length p on n labelled elements.
///
/// Preconditions: n >= 0 and p >= 1 (p == 0 would divide by zero).
/// NOTE(review): modular inverses of factorials assume n < MOD; if n >= MOD
/// then n! == 0 (mod MOD) and every term collapses to 0 — confirm against the
/// problem's constraints.
ll solve(ll n, ll p) {
    std::vector<ll> fact(n + 1);
    std::vector<ll> inv_fact(n + 1);

    fact[0] = 1;
    for (ll i = 1; i <= n; ++i) {
        fact[i] = fact[i - 1] * i % MOD;
    }

    // One Fermat inverse (MOD is prime), then walk down: 1/(i-1)! = i * 1/i!.
    inv_fact[n] = modpow(fact[n], MOD - 2, MOD);
    for (ll i = n; i > 0; --i) {
        inv_fact[i - 1] = inv_fact[i] * i % MOD;
    }

    // p^{-i} is maintained incrementally instead of recomputed each iteration.
    const ll inv_p = modpow(p, MOD - 2, MOD);
    ll inv_p_pow = 1;

    ll ans = fact[n];
    for (ll i = 0; i <= n / p; ++i) {
        const ll term = fact[n] * inv_fact[i] % MOD
                        * inv_fact[n - i * p] % MOD
                        * inv_p_pow % MOD;
        ans = (ans - term + MOD) % MOD;  // both operands in [0, MOD): no negatives
        inv_p_pow = inv_p_pow * inv_p % MOD;
    }
    return ans;
}

/// Reads "n p" from stdin and prints solve(n, p).
/// Exits with status 1 on unreadable input or when n < 0 or p < 1, instead of
/// hitting undefined behaviour (division by zero / negative vector size).
int main() {
    std::cin.tie(nullptr);
    std::ios_base::sync_with_stdio(false);

    ll n = 0;
    ll p = 0;
    if (!(std::cin >> n >> p) || n < 0 || p < 1) {
        return 1;
    }

    std::cout << solve(n, p) << '\n';
    return 0;
}