結果
| 問題 |
No.2381 Gift Exchange Party
|
| コンテスト | |
| ユーザー |
|
| 提出日時 | 2023-07-15 18:40:44 |
| 言語 | C++23 (gcc 13.3.0 + boost 1.87.0) |
| 結果 |
AC
|
| 実行時間 | 30 ms / 2,000 ms |
| コード長 | 5,214 bytes |
| コンパイル時間 | 2,662 ms |
| コンパイル使用メモリ | 251,856 KB |
| 実行使用メモリ | 8,064 KB |
| 最終ジャッジ日時 | 2024-09-17 03:34:15 |
| 合計ジャッジ時間 | 3,523 ms |
|
ジャッジサーバーID (参考情報) |
judge3 / judge5 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 3 |
| other | AC * 22 |
ソースコード
#include <bits/stdc++.h>
using namespace std;
// right operation
// inplace
template <typename T>
void cumprod(vector<T> &a) {
int n = a.size();
for (int i = 0; i < n - 1; i++) {
a[i + 1] *= a[i];
}
}
template <typename T>
auto factorial(int n) -> vector<T> {
vector<T> a(n);
a[0] = 1;
for (int i = 1; i < n; ++i) {
a[i] = a[i - 1] * i;
}
return a;
}
template <typename T>
auto inv_factorial(int n) -> vector<T> {
vector<T> a(n);
a[n - 1] = 1 / factorial<T>(n)[n - 1];
for (int i = n; --i;) {
a[i - 1] = a[i] * i;
}
return a;
}
template <typename T>
class with_fact {
public:
vector<T> f, fi;
with_fact(int n)
: f(factorial<T>(n)),
fi(inv_factorial<T>(n)) {}
auto p(int n, int k) -> T {
return (k < 0 || n < k) ? 0 : f[n] * fi[n - k];
}
auto c(int n, int k) -> T {
return (k < 0 || n < k) ? 0 : p(n, k) * fi[k];
}
auto h(int n, int k) -> T { return c(n - 1 + k, k); }
auto inv(int n) -> T { return f[n - 1] * fi[n]; }
auto inv_p(int n, int k) -> T {
assert(0 <= k && k <= n);
return fi[n] * f[n - k];
}
auto inv_c(int n, int k) -> T { return inv_p(n, k) * f[k]; }
};
template <int p = -1>
class modint {
long v;
static int mod;
public:
static void set_mod(int m) { mod = m; }
constexpr static int m() { return p > 0 ? p : mod; }
constexpr modint(): v() {}
modint(long v): v(norm(v)) {}
static long norm(long x) {
if (x < -m() || x >= m()) {
x %= m();
}
if (x < 0) {
x += m();
}
return x;
}
int operator()() const { return v; }
modint operator-() const { return modint(m() - v); }
modint &operator+=(const modint &a) {
if ((v += a.v) >= m()) {
v -= m();
}
return *this;
}
modint &operator-=(const modint &a) { return *this += -a; }
modint &operator*=(const modint &a) {
v = norm(v * a.v);
return *this;
}
modint pow(long t) const {
if (t < 0) {
return pow(p - 2) * pow(-t);
}
if (t == 0) {
return 1;
}
modint a = pow(t >> 1);
a *= a;
if (t & 1) {
a *= *this;
}
return a;
}
modint inv() const { return pow(p - 2); }
modint &operator/=(const modint &a) {
return *this *= a.inv();
}
auto operator++() -> modint & { return *this += 1; }
auto operator--() -> modint & { return *this -= 1; }
auto operator++(int) -> modint {
modint a(*this);
*this += 1;
return a;
}
auto operator--(int) -> modint {
modint a(*this);
*this -= 1;
return a;
}
friend modint operator+(const modint &a, const modint &b) {
return modint(a) += b;
}
friend modint operator-(const modint &a, const modint &b) {
return modint(a) -= b;
}
friend modint operator*(const modint &a, const modint &b) {
return modint(a) *= b;
}
friend modint operator/(const modint &a, const modint &b) {
return modint(a) /= b;
}
friend bool operator==(const modint &a, const modint &b) {
return a.v == b.v;
}
friend istream &operator>>(istream &in, modint &x) {
in >> x.v;
x.v = norm(x.v);
return in;
}
friend ostream &operator<<(ostream &out, const modint &x) {
return out << x.v;
}
};
auto main() -> int {
// permutation functional graphの場合は
// どのノードもサイクルに含まれる
// サイクルサイズはN以下
// サイクルサイズがpの約数であるものが存在するとダメ。
// pの約数は1, pのみ
// p > n のときは1だけ考えれば良くて
// A_i = iが一つでもあるとダメ。
// p <= nのときは
// 1のときと、それプラス
// nこのうちp個からなるサイクルが1個以上あるような場合のかず。
// どっちもほうじょげんりでできそうだが。
// 1, pの場合の重複をどうやって省くか。
// g(n,p) = n! - f(n, p)
// f(3, 2) = 3C1*2! - 3C2*1! +3C3*0!
// + 3C2*1! - 3C2*1C1*0! (2一個と 1一個か二個)
// = 6 - 3 + 1 + 3 - 3 = 4
// g(3, 2) = 3! - 4 = 2
// 逆か
// 少なくとも一人同じだったらダメとしてたぇdp
// 少なくとも一人異なっていれば良い。
// 全員が同じになる場合を除けば良い
// どのサイクルも長さが1またはp
// pのサイクル数kを決め打つ。
// kこのいづれかのサイクルに入る人を選ぶ、
// 選ばれた人をp人ずつ分割する場合のかず
// 選ばれなかったn - k人は自己ループ確定
// nCk求めるやつ書くか。
int n, p;
cin >> n >> p;
using mint = modint<998'244'353>;
// using mint = modint<1'000'000'007>;
with_fact<mint> f(n + 1);
// mint cnt;
// for (int k = 0; k <= n / p; k++) {
// mint v = 1;
// int m = n;
// for (int i = 0; i < k; i++) {
// v *= f.c(m, p) * (f.f[p] - 1);
// m -= p;
// }
// cnt += v / f.fi[k];
// }
// cout << f.f[n] - cnt << '\n';
mint cnt = f.f[n];
for (int a = 0; a <= n; a++) {
int b = n - p * a;
if (b < 0) {
break;
}
if (a == 0) {
cnt -= 1;
continue;
}
mint x = f.f[n] * f.inv(p).pow(a);
x *= f.fi[a] * f.fi[b];
cnt -= x;
}
cout << cnt << '\n';
}