結果
| 問題 | No.2108 Red or Blue and Purple Tree |
|---|---|
| ユーザー | hitonanode |
| 提出日時 | 2022-09-20 21:28:02 |
| 言語 | C++23 (gcc 13.3.0 + boost 1.87.0) |
| 結果 | AC |
| 実行時間 | 2,529 ms / 4,000 ms |
| コード長 | 4,864 bytes |
| コンパイル時間 | 11,543 ms |
| コンパイル使用メモリ | 324,316 KB |
| 実行使用メモリ | 76,552 KB |
| 最終ジャッジ日時 | 2024-12-22 03:41:28 |
| 合計ジャッジ時間 | 35,985 ms |
| ジャッジサーバーID (参考情報) | judge2 / judge4 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 1 |
| other | AC * 7 |
ソースコード
// Cross-check the fast solution against a brute-force solution for all n <= 7.
#include <algorithm>
#include <cassert>
#include <iostream>
#include <utility>
#include <vector>
using namespace std;
#include "testlib.h"
#include <atcoder/modint>
#include <atcoder/convolution>
using mint = atcoder::modint998244353;
// Precomputed factorial / inverse-factorial table over a modular-integer type.
//   fac[i]      == i!        for 0 <= i <= N
//   fac.finv(i) == (i!)^(-1) for 0 <= i <= N
// N == -1 is accepted and yields two empty tables; nothing may be queried then.
// modint must provide: static mod(), construction from int, operator*, inv().
// Both accessors use vector::at, so an out-of-range index throws std::out_of_range.
template <typename modint> struct acl_fac {
    std::vector<modint> facs, facinvs;
    acl_fac(int N) {
        // N < mod keeps every i! (i <= N) nonzero for a prime modulus, so inv() exists.
        assert(-1 <= N and N < modint::mod());
        facs.resize(N + 1, 1);
        for (int i = 1; i <= N; i++) facs[i] = facs[i - 1] * i;
        // With N == -1 both tables are empty; facs.back() would be undefined behaviour.
        if (N < 0) return;
        // Invert only N! and walk down: (i-1)!^(-1) = i!^(-1) * i.
        facinvs.assign(N + 1, facs.back().inv());
        for (int i = N; i > 0; i--) facinvs[i - 1] = facinvs[i] * i;
    }
    modint operator[](int i) const { return facs.at(i); }
    modint finv(int i) const { return facinvs.at(i); }
};
// Largest n covered by the precomputed tables; also the upper bound accepted for n in the input.
constexpr int max_n{2000};
// Global factorial / inverse-factorial table covering 0..max_n.
acl_fac<mint> fac{max_n};
#include <numeric>
#include <stack>
#include <utility>
#include <vector>
// Disjoint-set union (union by size, no path compression) whose unite() calls
// can be rolled back one at a time, most recent first, with undo().
// Originally written for Educational Codeforces Round 62, problem F; not yet
// verified against that problem.
struct UndoUnionFind {
    using pint = std::pair<int, int>;
    // par: parent link; cou: component size (meaningful at roots).
    std::vector<int> par, cou;
    // Each entry: (root that unite() attached, (its previous parent,
    // previous size of the absorbing root)).
    std::stack<std::pair<int, pint>> history;

    // Creates N singleton sets {0}, {1}, ..., {N-1}.
    UndoUnionFind(int N) : par(N), cou(N, 1) {
        for (int v = 0; v < N; ++v) par[v] = v;
    }

    // Root of the set containing x. Paths are never compressed, so the only
    // mutations are the ones unite() records in history.
    int find(int x) const {
        while (par[x] != x) x = par[x];
        return x;
    }

    // Merges the sets of x and y, attaching the smaller root under the larger.
    // Returns false when they were already together. A history entry is pushed
    // in both cases, so every unite() is matched by exactly one undo().
    bool unite(int x, int y) {
        int keep = find(x), attach = find(y);
        if (cou[keep] < cou[attach]) std::swap(keep, attach);
        history.emplace(attach, pint(par[attach], cou[keep]));
        if (keep == attach) return false;
        par[attach] = keep;
        cou[keep] += cou[attach];
        return true;
    }

    // Reverts the most recent unite() still recorded. Precondition: history is non-empty.
    void undo() {
        const auto [node, saved] = history.top();
        history.pop();
        cou[par[node]] = saved.second;
        par[node] = saved.first;
    }

    // Rolls back every recorded unite().
    void reset() {
        while (history.size() > 0) undo();
    }

    // Size of the set containing x.
    int count(int x) const {
        const int root = find(x);
        return cou[root];
    }

    // Whether x and y currently belong to the same set.
    bool same(int x, int y) const {
        const int rx = find(x);
        const int ry = find(y);
        return rx == ry;
    }
};
// Brute-force reference table for small n.
// Returns ret where ret[n][k] (1 <= n <= maxn, 0 <= k <= n - 1) is the number of
// ordered pairs (T1, T2) of spanning trees of the complete graph on n labelled
// vertices that have exactly k edges in common. ret[0] is an empty placeholder
// so that ret can be indexed directly by n.
// Cost is quadratic in the number of spanning trees (n^(n-2) by Cayley's
// formula), so only call this with very small maxn.
vector<vector<mint>> gen_bruteforce_table(int maxn) {
    using mask_t = unsigned long long;
    vector<vector<mint>> ret(1);
    for (int n = 1; n <= maxn; ++n) {
        // Edges of K_n. Self-loops are omitted: they can never belong to a tree
        // and would only waste bit positions in the masks below.
        vector<pair<int, int>> edges;
        for (int i = 0; i < n; ++i) {
            for (int j = i + 1; j < n; ++j) edges.emplace_back(i, j);
        }
        // Each tree is encoded as a bitmask over edge indices; the shift below
        // must stay within the width of mask_t.
        assert(edges.size() <= 64);
        vector<mask_t> tree_masks;
        UndoUnionFind uf(n);
        // Decide edges in index order (skip / take). An edge is only taken if it
        // joins two different components, so every taken set is a forest; stop as
        // soon as it spans all n vertices, so each spanning tree is recorded once.
        auto rec = [&](auto &&self, int d, mask_t m) -> void {
            if (uf.count(0) == n) {
                tree_masks.push_back(m);
                return;
            }
            if (d == int(edges.size())) return;
            self(self, d + 1, m);
            auto [s, t] = edges.at(d);
            if (!uf.same(s, t)) {
                uf.unite(s, t);
                self(self, d + 1, m | (mask_t{1} << d));
                uf.undo();
            }
        };
        rec(rec, 0, 0);
        // Two spanning trees share popcount(m1 & m2) <= n - 1 edges, so the index is in range.
        vector<mint> tmp(n);
        for (auto m1 : tree_masks) {
            for (auto m2 : tree_masks) tmp[__builtin_popcountll(m1 & m2)] += 1;
        }
        ret.push_back(std::move(tmp));
    }
    return ret;
}
// Program entry point.
// 1. Precomputes answers[n][k] for all 1 <= n <= max_n and 0 <= k <= n - 1,
//    modulo 998244353 (the modulus of mint).
// 2. Asserts that they match gen_bruteforce_table(7) for 2 <= n <= 7. That is the
//    number of ordered pairs of spanning trees of the complete graph on n
//    labelled vertices with exactly k edges in common.
// 3. Validates stdin with testlib and prints answers[n][k] for each query.
int main(int argc, char *argv[]) {
// testlib validator mode: the inf.read* calls below enforce the exact input
// format (value ranges, single spaces, line ends, end of file).
registerValidation(argc, argv);
// f[a] = a^(a-2) * a^2 / a! for a >= 1. a == 1 is special-cased because the
// exponent a - 2 would be negative. f[0] stays 0.
// NOTE(review): presumably the EGF weight of one tree component of size a:
// a^(a-2) labelled trees (Cayley) times a^2 from the squared
// "n^(c-2) * prod(a_i)" count of spanning trees that contain a given forest.
// TODO confirm against the problem editorial.
vector<mint> f(max_n + 1);
for (int n = 1; n <= max_n; ++n) {
f.at(n) = (n == 1 ? 1 : mint(n).pow(n - 2)) * n * n * fac.finv(n);
}
// dp[k] = f^k (k-th convolution power), truncated to degrees 0..max_n.
// Since f[0] == 0, dp[k][n] == 0 whenever n < k.
vector<vector<mint>> dp(max_n + 1);
dp.at(0) = vector<mint>(max_n + 1);
dp.at(0).at(0) = 1;
for (int k = 1; k <= max_n; ++k) {
dp.at(k) = atcoder::convolution(f, dp.at(k - 1));
dp.at(k).resize(max_n + 1);
}
// Rescale for n >= k:
//   k > 1 : dp[k][n] *= n! / k! * n^(2(k-2))
//   k == 1: dp[1][n]  = n^(n-2), set directly because n^(2(k-2)) would need a
//           negative exponent.
// NOTE(review): presumably n!/k! turns the EGF coefficient into a sum over
// unordered partitions into k blocks. The result dp[k][n] would then be the sum,
// over spanning forests of K_n with k components, of
// (number of spanning trees containing the forest)^2. TODO confirm.
for (int k = 1; k <= max_n; ++k) {
for (int n = k; n <= max_n; ++n) {
if (k > 1) {
dp.at(k).at(n) *= fac[n] * fac.finv(k) * mint(n).pow(2 * (k - 2));
} else {
dp.at(k).at(n) = (n > 1 ? mint(n).pow(n - 2) : 1);
}
}
}
// answers[0] is an empty placeholder so answers can be indexed by n.
vector<vector<mint>> answers(1);
for (int n = 1; n <= max_n; ++n) {
// Let h[i] := dp[n - i][n] for i = 0..n-1 (a forest on n vertices with n - i
// components has i edges). The steps below compute
//   answers[n][i] = sum_{j=i}^{n-1} (-1)^(j-i) * C(j, i) * h[j],
// the binomial inversion of h[i] = sum_{k>=i} C(k, i) * answers[n][k].
// Done with one convolution:
//   g[j] = (-1)^j * j! * h[j]; reverse; convolve with 1/s!; keep n terms;
//   reverse (now g[i] = sum_j g_old[j] / (j-i)!); scale by (-1)^i / i!.
vector<mint> g(n);
for (int i = 0; i < n; ++i) g.at(i) = dp.at(n - i).at(n) * fac[i] * (i % 2 ? -1 : 1);
vector<mint> fac_trans(n);
for (int i = 0; i < n; ++i) fac_trans.at(i) = fac.finv(i);
reverse(g.begin(), g.end());
g = atcoder::convolution(g, fac_trans);
g.resize(n);
reverse(g.begin(), g.end());
for (int i = 0; i < n; ++i) g.at(i) *= fac.finv(i) * (i % 2 ? -1 : 1);
answers.push_back(g);
}
// Self-check against brute force. The assert is compiled out under NDEBUG, but
// the brute-force table is still computed.
vector<vector<mint>> bf_answers = gen_bruteforce_table(7);
for (int n = 2; n <= 7; ++n) {
for (int k = 0; k <= n - 1; ++k) assert(answers.at(n).at(k) == bf_answers.at(n).at(k));
}
// Input: T (1 <= T <= 500000) on the first line, then T lines "n k" with
// 2 <= n <= max_n and 0 <= k <= n - 1. Print answers[n][k] for each query.
int T = inf.readInt(1, 500000);
inf.readEoln();
while (T--) {
const int n = inf.readInt(2, max_n);
inf.readSpace();
const int k = inf.readInt(0, n - 1);
inf.readEoln();
cout << answers.at(n).at(k).val() << '\n';
}
inf.readEof();
}
hitonanode