#include #include using namespace std; using mint = atcoder::modint998244353; int main() { ios::sync_with_stdio(false); cin.tie(nullptr); int n, p; cin >> n >> p; vector fact(n + 1, 1); for (int i = 2; i <= n; i++) { fact[i] = fact[i - 1] * i; } mint ans = 1; mint pat = 1; for (int i = 0; i < n / p; i++) { pat *= fact[n - i * p] / (fact[p] * fact[n - (i + 1) * p]); pat *= fact[p - 1]; pat /= (i + 1); ans += pat; } ans = fact[n] - ans; cout << ans.val() << endl; }