結果

問題 No.2258 The Jikka Tree
ユーザー Jaehyun Koo
提出日時 2023-06-04 21:41:26
言語 C++17
(gcc 13.3.0 + boost 1.87.0)
結果
AC  
実行時間 3,670 ms / 4,000 ms
コード長 3,972 bytes
コンパイル時間 2,215 ms
コンパイル使用メモリ 199,028 KB
最終ジャッジ日時 2025-02-13 22:45:19
ジャッジサーバーID
(参考情報)
judge2 / judge2
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 1
other AC * 75
権限があれば一括ダウンロードができます

ソースコード

diff #

#include <bits/stdc++.h>
using namespace std;
// 64-bit signed integer used for all weight sums.
using lint = long long;
// Pair of 64-bit values; the functions below use slot 0 for a weight sum and
// slot 1 for a 0/1 marker count.
using pi = array<lint, 2>;
// Container size as an int (not used by the code visible in this file).
#define sz(v) ((int)(v).size())
// Expands to the begin/end iterator pair of a container (not used by the code
// visible in this file).
#define all(v) (v).begin(), (v).end()
// Capacity of the per-vertex arrays (gph, din, dout, par, dep, rev, asum).
const int MAXN = 150005;
// Capacity of the segment-tree node pool `tree`.
const int MAXT = 4000000;

// One node of the segment tree. Children are stored as indices into the
// global pool `tree`; the pool has static storage, so every slot starts
// zeroed.
struct node {
	int l;    // pool index of the left child
	int r;    // pool index of the right child
	lint sum; // total of the values stored below this node
	int cnt;  // number of values stored below this node

	// Weight of this node when every stored value is increased by k:
	// sum + k * cnt.
	lint eval(lint k) {
		const lint bonus = k * cnt;
		return sum + bonus;
	}
} tree[MAXT];

// Index of the most recently allocated slot in the node pool (0 = none yet).
int p;

// Hands out the next unused pool slot and returns its index.
// The first call returns 1, so slot 0 is never allocated.
int newnode() {
	p += 1;
	return p;
}

void pull(int p) {
	tree[p].sum = tree[tree[p].l].sum + tree[tree[p].r].sum;
	tree[p].cnt = tree[tree[p].l].cnt + tree[tree[p].r].cnt;
}

// Builds the skeleton of an empty segment tree over positions [s, e] below
// node p: every internal node gets two freshly allocated children. Leaves are
// left untouched (their fields stay at the pool's zero initial state).
void init(int s, int e, int p) {
	if (s != e) {
		// Allocate left first, then right, before descending.
		const int left = newnode();
		const int right = newnode();
		tree[p].l = left;
		tree[p].r = right;
		const int mid = (s + e) / 2;
		init(s, mid, left);
		init(mid + 1, e, right);
	}
}

void update(int pos, int s, int e, int r1, int r2, int v) {
	if (s == e) {
		tree[r2] = tree[r1];
		tree[r2].sum += v;
		tree[r2].cnt += 1;
		return;
	}
	int m = (s + e) / 2;
	if (pos <= m) {
		tree[r2].l = newnode();
		tree[r2].r = tree[r1].r;
		update(pos, s, m, tree[r1].l, tree[r2].l, v);
	} else {
		tree[r2].l = tree[r1].l;
		tree[r2].r = newnode();
		update(pos, m + 1, e, tree[r1].r, tree[r2].r, v);
	}
	pull(r2);
}

// Adjacency lists of the input tree, indexed by vertex; filled in main().
vector<int> gph[MAXN];
// State produced by dfs() and main():
//   din[x]   - preorder index assigned to x on entry
//   dout[x]  - value of the preorder counter after x's subtree is finished,
//              so x's subtree occupies preorder indices [din[x], dout[x] - 1]
//   par[0]   - parent of each vertex; par[i] is built in main() by composing
//              par[i - 1] with itself (binary-lifting table)
//   dep[x]   - depth of x, with the DFS start vertex at depth 0
//   rev[t]   - vertex whose preorder index is t (inverse of din)
//   piv      - running preorder counter
int din[MAXN], dout[MAXN], par[18][MAXN], dep[MAXN], rev[MAXN], piv;

// Depth-first traversal from x, whose parent is p (pass a value that is not a
// vertex for the start). Assigns preorder indices (din / rev), records the
// parent and depth of every child, and stores in dout[x] the counter value
// reached once x's whole subtree has been numbered.
void dfs(int x, int p) {
	const int order = piv++;
	din[x] = order;
	rev[order] = x;
	for (int y : gph[x]) {
		if (y == p)
			continue; // do not walk back up the edge we came from
		par[0][y] = x;
		dep[y] = dep[x] + 1;
		dfs(y, x);
	}
	dout[x] = piv;
}

// Decides whether the weight sum[0] counts as "less than half" of tot.
// Strictly below half -> true. Exactly half -> true only when the marker
// count sum[1] is zero. Anything else -> false.
bool lessHalf(pi sum, lint tot) {
	const lint doubled = sum[0] * 2;
	if (doubled == tot)
		return sum[1] == 0; // tie: decided by the marker
	return doubled < tot;
}

int get_med(int s, int e, int p1, int p2, int k, int t, pi curSum, lint tot) {
	if (s == e)
		return s;
	int m = (s + e) / 2;
	pi cmpSum = curSum;
	cmpSum[0] += tree[tree[p2].l].eval(k) - tree[tree[p1].l].eval(k);
	cmpSum[1] += (s <= t && t <= m ? 1 : 0);
	if (lessHalf(cmpSum, tot)) {
		return get_med(m + 1, e, tree[p1].r, tree[p2].r, k, t, cmpSum, tot);
	}
	return get_med(s, m, tree[p1].l, tree[p2].l, k, t, curSum, tot);
}

// Component-wise addition of two pairs.
pi operator+(const pi &a, const pi &b) {
	pi out = a;
	out[0] += b[0];
	out[1] += b[1];
	return out;
}

// Range query on the difference of two tree versions (p2 minus p1).
// [s, e] is the queried position range; [ps, pe] is the range covered by the
// current node pair. Returns {weight, marker}, where weight is the sum of
// eval(k) differences over the queried positions and marker is 1 when
// position t falls inside the queried range, 0 otherwise.
pi get_sum(int s, int e, int ps, int pe, int p1, int p2, int k, int t) {
	const bool disjoint = e < ps || pe < s;
	if (disjoint)
		return pi{0, 0};
	const bool covered = s <= ps && pe <= e;
	if (covered) {
		const lint weight = tree[p2].eval(k) - tree[p1].eval(k);
		const lint marker = (ps <= t && t <= pe) ? 1 : 0;
		return pi{weight, marker};
	}
	const int pm = (ps + pe) / 2;
	const pi leftPart = get_sum(s, e, ps, pm, tree[p1].l, tree[p2].l, k, t);
	const pi rightPart = get_sum(s, e, pm + 1, pe, tree[p1].r, tree[p2].r, k, t);
	return leftPart + rightPart;
}

// Number of vertices; read in main().
int n;
// Prefix sums of the input values: asum[i] = a[0] + ... + a[i - 1]
// (filled in main()).
lint asum[MAXN];
// root[i] is the pool index of the segment-tree version that contains the
// first i input values; root[0] is the empty tree (filled in main()).
vector<int> root;

int query(int l, int r, int k, int d) {
	//	cout << l << " " << r + 1 << " " << k << " " << d << "\n";
	lint tot = asum[r + 1] - asum[l] + 1ll * (r - l + 1) * k;
	int v = get_med(0, n - 1, root[l], root[r + 1], k, din[d], pi{0, 0}, tot);
	v = rev[v];
	auto sum = get_sum(din[v], dout[v] - 1, 0, n - 1, root[l], root[r + 1], k, din[d]);
	if (!lessHalf(sum, tot))
		return v;
	for (int i = 17; i >= 0; i--) {
		if (dep[v] >= (1 >> i)) {
			int anc = par[i][v];
			auto sum = get_sum(din[anc], dout[anc] - 1, 0, n - 1, root[l], root[r + 1], k, din[d]);
			if (lessHalf(sum, tot))
				v = par[i][v];
		}
	}
	return par[0][v];
}

int main() {
	ios_base::sync_with_stdio(0);
	cin.tie(0);
	cout.tie(0);
	cin >> n;
	for (int i = 0; i < n - 1; i++) {
		int u, v;
		cin >> u >> v;
		gph[u].push_back(v);
		gph[v].push_back(u);
	}
	dfs(0, -1);
	for (int i = 1; i < 18; i++) {
		for (int j = 0; j < n; j++) {
			par[i][j] = par[i - 1][par[i - 1][j]];
		}
	}
	vector<int> a(n);
	for (auto &x : a)
		cin >> x;
	root.resize(n + 1);
	root[0] = newnode();
	init(0, n - 1, root[0]);
	for (int i = 1; i <= n; i++) {
		root[i] = newnode();
		update(din[i - 1], 0, n - 1, root[i - 1], root[i], a[i - 1]);
		asum[i] = asum[i - 1] + a[i - 1];
	}
	int q;
	cin >> q;
	lint S = 0;
	while (q--) {
		lint a, b, z, d;
		cin >> a >> b >> z >> d;
		a += S;
		b += S * 2;
		z += (__int128)S * S % 150001;
		a %= n;
		b %= n;
		z %= 150001;
		if (a > b)
			swap(a, b);
		int ans = query(a, b, z, d);
		cout << ans << "\n";
		S += ans;
	}
}
0