結果
| 問題 |
No.2294 Union Path Query (Easy)
|
| コンテスト | |
| ユーザー |
Kude
|
| 提出日時 | 2023-05-05 23:20:17 |
| 言語 | C++17 (gcc 13.3.0 + boost 1.87.0) |
| 結果 |
AC
|
| 実行時間 | 190 ms / 4,000 ms |
| コード長 | 4,658 bytes |
| コンパイル時間 | 2,475 ms |
| コンパイル使用メモリ | 222,312 KB |
| 最終ジャッジ日時 | 2025-02-12 20:03:32 |
|
ジャッジサーバーID (参考情報) |
judge2 / judge1 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 3 |
| other | AC * 46 |
ソースコード
#include<bits/stdc++.h>
namespace {
#pragma GCC diagnostic ignored "-Wunused-function"
#include<atcoder/all>
#pragma GCC diagnostic warning "-Wunused-function"
using namespace std;
using namespace atcoder;
#define rep(i,n) for(int i = 0; i < (int)(n); i++)
#define rrep(i,n) for(int i = (int)(n) - 1; i >= 0; i--)
#define all(x) begin(x), end(x)
#define rall(x) rbegin(x), rend(x)
// Raises `a` to `b` when `b` is strictly greater (compared with operator<).
// Returns true exactly when `a` was overwritten.
template<class T> bool chmax(T& a, const T& b) {
    if (a < b) {
        a = b;
        return true;
    }
    return false;
}
// Lowers `a` to `b` when `b` is strictly smaller (compared with operator<).
// Returns true exactly when `a` was overwritten.
template<class T> bool chmin(T& a, const T& b) {
    if (b < a) {
        a = b;
        return true;
    }
    return false;
}
using ll = long long;
using P = pair<int,int>;
using VI = vector<int>;
using VVI = vector<VI>;
using VL = vector<ll>;
using VVL = vector<VL>;
// Union-find in which every element carries a potential W(v) taken from a
// group S.  `op` is the group operation, `e()` its identity and `inv` the
// inverse.  For two elements of one component the relative weight is
//     W(a -> b) = inv(W(a)) op W(b).
//
// Internally pot_[v] stores the weight of the link parent(v) -> v; after
// leader(v) has run, v hangs directly below its leader, so pot_[v] is then the
// accumulated weight from the leader down to v.  A leader keeps pot_ == e().
template <class S, S (*op)(S, S), S (*e)(), S (*inv)(S)>
struct weighted_union_find {
  public:
    // Empty structure with no elements.
    weighted_union_find() : n_(0) {}

    // n singleton components, every potential equal to e().
    explicit weighted_union_find(int n) : n_(n), link_(n, -1), pot_(n, e()) {}

    // Joins the components of a and b so that W(a -> b) == w afterwards.
    // a and b must currently be in different components (asserted).
    // Returns the leader of the merged component.
    int merge(int a, int b, S w) {
        assert(0 <= a && a < n_);
        assert(0 <= b && b < n_);
        int keep = leader(a);
        int gone = leader(b);
        assert(keep != gone);
        // Union by size: sizes are stored negated, so a larger stored value
        // means a smaller component.  Make `keep` the bigger side and flip
        // the constraint to match the swapped endpoints.
        if (link_[keep] > link_[gone]) {
            std::swap(keep, gone);
            std::swap(a, b);
            w = inv(w);
        }
        // The new link weight Wg must satisfy inv(Wa) op Wg op Wb == w,
        // hence Wg = Wa op w op inv(Wb).
        const S through = op(pot_[a], w);
        pot_[gone] = op(through, inv(pot_[b]));
        link_[keep] += link_[gone];
        link_[gone] = keep;
        return keep;
    }

    // Relative weight W(a -> b); a and b must share a component (asserted).
    S diff(int a, int b) {
        // The leader() calls also flatten both paths, which makes pot_[a] and
        // pot_[b] potentials measured from the common leader.
        const int root_a = leader(a);
        const int root_b = leader(b);
        assert(root_a == root_b);
        return op(inv(pot_[a]), pot_[b]);
    }

    // True when a and b belong to the same component.
    bool same(int a, int b) {
        assert(0 <= a && a < n_);
        assert(0 <= b && b < n_);
        const int root_a = leader(a);
        const int root_b = leader(b);
        return root_a == root_b;
    }

    // Leader of a's component.  Compresses the whole path: every visited node
    // is re-attached directly to the leader with its accumulated potential.
    int leader(int a) {
        assert(0 <= a && a < n_);
        // Upward pass: walk to the root while reversing each link, so the
        // path can afterwards be replayed from the top down without a stack.
        int below = -1;
        for (int up = link_[a]; up >= 0; up = link_[a]) {
            link_[a] = below;
            below = a;
            a = up;
        }
        // Downward pass: `a` is the root now.  Fold the link weights in
        // root-to-leaf order and point every node straight at the root.
        S acc = e();
        while (below != -1) {
            acc = op(acc, pot_[below]);
            pot_[below] = acc;
            below = std::exchange(link_[below], a);
        }
        return a;
    }

    // Number of elements in a's component.
    int size(int a) {
        assert(0 <= a && a < n_);
        return -link_[leader(a)];
    }

    // All components as lists of their members; members appear in increasing
    // order and components are ordered by the index of their leader.
    std::vector<std::vector<int>> groups() {
        std::vector<int> root_of(n_), members(n_);
        for (int v = 0; v < n_; ++v) {
            const int r = leader(v);
            root_of[v] = r;
            ++members[r];
        }
        std::vector<std::vector<int>> buckets(n_);
        for (int v = 0; v < n_; ++v) {
            buckets[v].reserve(members[v]);
        }
        for (int v = 0; v < n_; ++v) {
            buckets[root_of[v]].push_back(v);
        }
        // Only leader slots received members; drop the empty ones.
        const auto tail = std::remove_if(
            buckets.begin(), buckets.end(),
            [](const std::vector<int>& bucket) { return bucket.empty(); });
        buckets.erase(tail, buckets.end());
        return buckets;
    }

  private:
    int n_;
    // Leader: negated component size.  Any other node: index of its parent.
    std::vector<int> link_;
    // Weight of the link parent -> node (see the comment above the struct).
    std::vector<S> pot_;
};

// The XOR group on int, used to instantiate weighted_union_find.
// Group operation: bitwise XOR.
int op(int x, int y) {
    return x ^ y;
}
// Identity element of XOR.
int e() {
    return 0;
}
// Every element is its own inverse under XOR.
int inv(int x) {
    return x;
}
using mint = modint998244353;
} int main() {
ios::sync_with_stdio(false);
cin.tie(0);
int n, x, q;
cin >> n >> x >> q;
struct S {
mint sm;
int cnt[30]{};
};
vector<S> d(n);
weighted_union_find<int, op, e, inv> uf(n);
vector<mint> pow2(30);
rep(i, 30) pow2[i] = mint(2).pow(i);
rep(_, q) {
int type;
cin >> type;
if (type == 1) {
int v, w;
cin >> v >> w;
// cout << "merge" << v << ' ' << x << endl;
int lv = uf.leader(v), lx = uf.leader(x);
int szv = uf.size(lv), szx = uf.size(lx);
int l = uf.merge(v, x, w);
S nd;
nd.sm += d[lv].sm + d[lx].sm;
int diff = uf.diff(lv, lx);
rep(k, 30) {
if (diff >> k & 1) nd.sm += pow2[k] * (ll(szv - d[lv].cnt[k]) * (szx - d[lx].cnt[k]) + ll(d[lv].cnt[k]) * d[lx].cnt[k]);
else nd.sm += pow2[k] * (ll(szv - d[lv].cnt[k]) * d[lx].cnt[k] + ll(d[lv].cnt[k]) * (szx - d[lx].cnt[k]));
}
rep(_, 2) {
int diff = uf.diff(lv, l);
rep(k, 30) {
if (diff >> k & 1) nd.cnt[k] += szv - d[lv].cnt[k];
else nd.cnt[k] += d[lv].cnt[k];
}
swap(lv, lx);
swap(szv, szx);
}
d[l] = nd;
} else if (type == 2) {
int u, v;
cin >> u >> v;
int res = !uf.same(u, v) ? -1 : uf.diff(u, v);
cout << res << '\n';
if (res != -1) x = (x + res) % n;
} else if (type == 3) {
int v;
cin >> v;
cout << d[uf.leader(v)].sm.val() << '\n';
} else {
int value;
cin >> value;
x = (x + value) % n;
}
}
}
Kude