#include using namespace std; using ll = long long; #include using namespace atcoder; using mint = modint998244353; vector f,finv; void calc(int n){ f.resize(n); finv.resize(n); f[0] = 1; finv[0] = 1; for(int i = 1; i < n; i++){ f[i] = f[i - 1] * i; finv[i] = finv[i - 1] / i; } } mint comb(int n, int r){ return f[n] * finv[r] * finv[n - r]; } int main(){ calc(300000); ll n, p; cin >> n >> p; mint ans = 0; mint pre = 1; for(int i = n; i >= 0; i--){ if(n % p == i % p){ mint v = comb(n, i); int r = n - i; if(r) for(int j = 0; j < p; j++) pre *= r - j; mint tmp = pre; tmp /= mint(p).pow(r / p); tmp *= finv[r / p]; ans += v * tmp; } } ans = f[n] - ans; cout << ans.val() << '\n'; }