//https://suisen-kyopro.hatenablog.com/entry/2023/11/22/201600 よりお借りした #include #include template ::value, std::nullptr_t> = nullptr> std::vector arbitrary_mod_convolution(const std::vector& a, const std::vector& b) { int n = int(a.size()), m = int(b.size()); { // check if the mod is ntt-friendly int maxz = 1; while (not ((mint::mod() - 1) & maxz)) { maxz <<= 1; } int z = 1; while (z < n + m - 1) { z <<= 1; } if (z <= maxz) { return atcoder::convolution(a, b); } } if (n == 0 or m == 0) return {}; //if (std::min(n, m) <= 120) return atcoder::internal::convolution_naive(a, b); static constexpr long long MOD1 = 754974721; // 2^24 static constexpr long long MOD2 = 167772161; // 2^25 static constexpr long long MOD3 = 469762049; // 2^26 static constexpr long long M1M2 = MOD1 * MOD2; static constexpr long long INV_M1_MOD2 = atcoder::internal::inv_gcd(MOD1, MOD2).second; static constexpr long long INV_M1M2_MOD3 = atcoder::internal::inv_gcd(M1M2, MOD3).second; std::vector a2(n), b2(m); for (int i = 0; i < n; ++i) a2[i] = a[i].val(); for (int i = 0; i < m; ++i) b2[i] = b[i].val(); auto c1 = atcoder::convolution(a2, b2); auto c2 = atcoder::convolution(a2, b2); auto c3 = atcoder::convolution(a2, b2); const long long m1m2 = mint(M1M2).val(); std::vector c(n + m - 1); for (int i = 0; i < n + m - 1; ++i) { // Garner's Algorithm // X = x1 + x2 * m1 + x3 * m1 * m2 // x1 = c1[i], x2 = (c2[i] - x1) / m1 (mod m2), x3 = (c3[i] - x1 - x2 * m1) / m2 (mod m3) long long x1 = c1[i]; long long x2 = (atcoder::static_modint(c2[i] - x1) * INV_M1_MOD2).val(); long long x3 = (atcoder::static_modint(c3[i] - x1 - x2 * MOD1) * INV_M1M2_MOD3).val(); c[i] = x1 + x2 * MOD1 + x3 * m1m2; } return c; } template ::value, std::nullptr_t> = nullptr> struct factorial { using value_type = mint; factorial() = delete; static value_type fact(int n) { ensure(n + 1); return _fact[n]; } static value_type inv_fact(int n) { ensure(n + 1); return _inv_fact[n]; } static value_type binom(int n, int r) { if (r < 0 or r > n) return 0; return fact(n) * inv_fact(r) * inv_fact(n - r); } static value_type perm(int n, int r) { if (r < 0 or r > n) return 0; return fact(n) * inv_fact(n - r); } static void ensure(int size) { const int curr_size = _fact.size(); if (size <= curr_size) return; const int next_size = std::max(curr_size * 2, size); _fact.resize(next_size); _inv_fact.resize(next_size); for (int i = curr_size; i < next_size; ++i) { _fact[i] = _fact[i - 1] * i; } _inv_fact.back() = _fact.back().inv(); for (int i = next_size - 1; i > curr_size; --i) { _inv_fact[i - 1] = _inv_fact[i] * i; } } private: static inline std::vector _fact{ 1 }, _inv_fact{ 1 }; }; /** * Computes f(t),f(t+1),...,f(t+m-1) from f(0),f(1),...,f(n-1) */ template , std::is_invocable_r, Convolve, std::vector, std::vector> >, std::nullptr_t> = nullptr> std::vector shift_of_sampling_points(const std::vector& ys, mint t, int m, const Convolve& convolve) { const int n = ys.size(); factorial::ensure(std::max(n, m)); std::vector b = [&] { std::vector f(n), g(n); for (int i = 0; i < n; ++i) { f[i] = ys[i] * factorial::inv_fact(i); g[i] = (i & 1 ? -1 : 1) * factorial::inv_fact(i); } std::vector b = convolve(f, g); b.resize(n); return b; }(); std::vector e = [&] { std::vector c(n); mint prd = 1; std::reverse(b.begin(), b.end()); for (int i = 0; i < n; ++i) { b[i] *= factorial::fact(n - i - 1); c[i] = prd * factorial::inv_fact(i); prd *= t - i; } std::vector e = convolve(b, c); e.resize(n); return e; }(); std::reverse(e.begin(), e.end()); for (int i = 0; i < n; ++i) { e[i] *= factorial::inv_fact(i); } std::vector f(m); for (int i = 0; i < m; ++i) { f[i] = factorial::inv_fact(i); } std::vector res = convolve(e, f); res.resize(m); for (int i = 0; i < m; ++i) { res[i] *= factorial::fact(i); } return res; } /** * Computes f(t),f(t+1),...,f(t+m-1) from f(0),f(1),...,f(n-1) */ template ::value, std::nullptr_t> = nullptr> std::vector shift_of_sampling_points(const std::vector& ys, mint t, int m) { const auto convolve = [](const std::vector& a, const std::vector& b) { return atcoder::convolution(a, b); }; return shift_of_sampling_points(ys, t, m, convolve); } template ::value, std::nullptr_t> = nullptr> struct factorial_large { using value_type = mint; static constexpr int LOG_BLOCK_SIZE = 9; static constexpr int BLOCK_SIZE = 1 << LOG_BLOCK_SIZE; static constexpr int BLOCK_NUM = value_type::mod() >> LOG_BLOCK_SIZE; static inline int threshold = 2000000; factorial_large() = delete; static value_type fact(int n) { return n <= threshold ? factorial::fact(n) : _large_fact(n); } static value_type inv_fact(int n) { return n <= threshold ? factorial::inv_fact(n) : _large_fact(n).inv(); } static value_type binom(int n, int r) { if (r < 0 or r > n) return 0; return fact(n) * inv_fact(r) * inv_fact(n - r); } static value_type perm(int n, int r) { if (r < 0 or r > n) return 0; return fact(n) * inv_fact(n - r); } private: static inline std::vector _block_fact{}; static void _build() { if (_block_fact.size()) { return; } std::vector f{ 1 }; f.reserve(BLOCK_SIZE); for (int i = 0; i < LOG_BLOCK_SIZE; ++i) { std::vector g = shift_of_sampling_points(f, 1 << i, 3 << i, arbitrary_mod_convolution); const auto get = [&](int j) { return j < (1 << i) ? f[j] : g[j - (1 << i)]; }; f.resize(2 << i); for (int j = 0; j < 2 << i; ++j) { f[j] = get(2 * j) * get(2 * j + 1) * ((2 * j + 1) << i); } } // f_B(x) = (x+1) * ... * (x+B-1) if (BLOCK_NUM > BLOCK_SIZE) { std::vector g = shift_of_sampling_points(f, BLOCK_SIZE, BLOCK_NUM - BLOCK_SIZE, arbitrary_mod_convolution); std::move(g.begin(), g.end(), std::back_inserter(f)); } else { f.resize(BLOCK_NUM); } for (int i = 0; i < BLOCK_NUM; ++i) { f[i] *= value_type(i + 1) * BLOCK_SIZE; } // f[i] = (i*B + 1) * ... * (i*B + B) f.insert(f.begin(), 1); for (int i = 1; i <= BLOCK_NUM; ++i) { f[i] *= f[i - 1]; } _block_fact = std::move(f); } static value_type _large_fact(int n) { _build(); value_type res; int q = n / BLOCK_SIZE, r = n % BLOCK_SIZE; if (2 * r <= BLOCK_SIZE) { res = _block_fact[q]; for (int i = 0; i < r; ++i) { res *= value_type::raw(n - i); } } else if (q != factorial_large::BLOCK_NUM) { res = _block_fact[q + 1]; value_type den = 1; for (int i = 1; i <= BLOCK_SIZE - r; ++i) { den *= value_type::raw(n + i); } res /= den; } else { // Wilson's theorem res = value_type::mod() - 1; value_type den = 1; for (int i = value_type::mod() - 1; i > n; --i) { den *= value_type::raw(i); } res /= den; } return res; } }; #include using namespace std; using namespace atcoder; using mint = atcoder::modint998244353; int main() { std::ios::sync_with_stdio(false); std::cin.tie(nullptr); //int t; //std::cin >> t; //while (t--) { // int n; // std::cin >> n; // std::cout << factorial_large::fact(n).val() << '\n'; //} int N, K; cin >> N >> K; K = min(K, N - K); if (N - K >= 998244353) { modint ans = 1; for (int i = N - K + 1;i <= N; i++) { ans *= i; } modint div = 1; for (int i = 1;i <= K;i++) { div *= i; } ans *= div.inv(); cout << ans.val() << endl; } else if (N >= 998244353 && N - K <= 998244353) { cout << 0 << endl; } else { modint a = factorial_large::fact(int(N)).val(); modint b = factorial_large::fact(int(K)).val(); modint c = factorial_large::fact(int(N - K)).val(); b *= c; a *= b.inv(); cout << a.val() << endl; } return 0; }