結果
問題 |
No.2258 The Jikka Tree
|
ユーザー |
![]() |
提出日時 | 2025-03-06 14:55:15 |
言語 | Java (openjdk 23) |
結果 |
WA
|
実行時間 | - |
コード長 | 5,740 bytes |
コンパイル時間 | 4,435 ms |
コンパイル使用メモリ | 89,120 KB |
実行使用メモリ | 467,488 KB |
最終ジャッジ日時 | 2025-03-06 14:56:43 |
合計ジャッジ時間 | 83,047 ms |
ジャッジサーバーID (参考情報) |
judge2 / judge3 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | WA * 1 |
other | AC * 3 WA * 27 TLE * 4 -- * 41 |
ソースコード
import java.io.BufferedWriter;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStreamWriter;
import java.io.PrintWriter;
import java.io.UncheckedIOException;
import java.util.*;

/**
 * Answers online "weighted centroid" queries on a tree rooted at vertex 0.
 *
 * <p>For a query (l, r, k, d) every vertex w with l <= w <= r gets weight a[w] + k and every
 * other vertex gets weight 0. The answer is the deepest vertex whose subtree is NOT
 * "less than half" of the total weight (see {@link #lessHalf(long[], long)}); vertex d acts as
 * a tie-breaker when a subtree holds exactly half of the total.
 *
 * <p>Data structures:
 * <ul>
 *   <li>a pre-order numbering of the tree ({@code din}/{@code dout}) so that a subtree is a
 *       contiguous range of positions;</li>
 *   <li>a persistent segment tree over those positions, where version i contains vertices
 *       0..i-1, so the difference of versions r+1 and l describes exactly the vertices l..r;</li>
 *   <li>a binary-lifting ancestor table to climb from the median position to the answer.</li>
 * </ul>
 *
 * <p>NOTE(review): the input order (edges, then weights, then queries) and 0-indexed vertices
 * are kept exactly as in the previous revision; confirm against the problem statement.
 */
public class Main {
    /** Modulus used when decoding the k parameter of a query. */
    static final long MOD = 150001L;

    // ---- persistent segment tree, stored in parallel arrays; node 0 is an all-zero sentinel ----
    static int[] lc;
    static int[] rc;
    static int[] cnt;
    static long[] sum;
    /** Index of the most recently allocated segment-tree node. */
    static int p;

    // ---- tree structure ----
    static int n;
    /** Number of binary-lifting levels; chosen so that 2^levels > n - 1 >= any depth. */
    static int levels;
    /** Neighbours of u are adj[adjStart[u] .. adjStart[u + 1] - 1], in input order. */
    static int[] adjStart;
    static int[] adj;
    static int[] din;
    static int[] dout;
    static int[] dep;
    static int[] rev;
    static int[][] par;
    static int piv;

    /** asum[i] = a[0] + ... + a[i - 1]. */
    static long[] asum;
    /** root[i] = segment-tree root of the version containing vertices 0..i-1. */
    static int[] root;

    /** Allocates a fresh segment-tree node and returns its index. */
    static int newNode() {
        return ++p;
    }

    /** Weight stored under {@code node} when every counted vertex gets an extra {@code k}. */
    static long eval(int node, long k) {
        return sum[node] + k * cnt[node];
    }

    /** Recomputes the aggregate of node {@code x} from its two children. */
    static void pull(int x) {
        sum[x] = sum[lc[x]] + sum[rc[x]];
        cnt[x] = cnt[lc[x]] + cnt[rc[x]];
    }

    /** Builds the empty version of the segment tree over positions [s, e] below node {@code x}. */
    static void init(int s, int e, int x) {
        if (s == e) {
            return;
        }
        int m = (s + e) >>> 1;
        lc[x] = newNode();
        rc[x] = newNode();
        init(s, m, lc[x]);
        init(m + 1, e, rc[x]);
    }

    /**
     * Creates a new version {@code r2} from version {@code r1} with value {@code v} added (and the
     * count incremented) at position {@code pos}. Only the nodes on the root-to-leaf path are
     * copied; everything else is shared with {@code r1}.
     */
    static void update(int pos, int s, int e, int r1, int r2, int v) {
        if (s == e) {
            sum[r2] = sum[r1] + v;
            cnt[r2] = cnt[r1] + 1;
            return;
        }
        int m = (s + e) >>> 1;
        if (pos <= m) {
            lc[r2] = newNode();
            rc[r2] = rc[r1];
            update(pos, s, m, lc[r1], lc[r2], v);
        } else {
            lc[r2] = lc[r1];
            rc[r2] = newNode();
            update(pos, m + 1, e, rc[r1], rc[r2], v);
        }
        pull(r2);
    }

    /**
     * Numbers the component of {@code x} in pre-order, filling din/dout/rev/dep and par[0].
     *
     * <p>Uses an explicit stack: the recursive form overflows the JVM call stack on path-like
     * trees with on the order of 10^5 vertices.
     *
     * @param x start vertex
     * @param p parent of {@code x}, or -1 if {@code x} is the root (par[0][x] is left untouched)
     */
    static void dfs(int x, int p) {
        int[] stack = new int[n];
        int[] cursor = Arrays.copyOf(adjStart, n); // next unvisited neighbour slot per vertex
        int top = 0;
        din[x] = piv++;
        rev[din[x]] = x;
        stack[top++] = x;
        while (top > 0) {
            int u = stack[top - 1];
            if (cursor[u] == adjStart[u + 1]) {
                // All neighbours handled: the subtree of u occupies positions [din[u], dout[u]).
                dout[u] = piv;
                top--;
                continue;
            }
            int y = adj[cursor[u]++];
            int up = (u == x) ? p : par[0][u];
            if (y == up) {
                continue;
            }
            par[0][y] = u;
            dep[y] = dep[u] + 1;
            din[y] = piv++;
            rev[din[y]] = y;
            stack[top++] = y;
        }
    }

    /**
     * Tells whether a part of the tree is strictly lighter than half of the total.
     *
     * @param sum {@code sum[0]} is the weight of the part, {@code sum[1]} is non-zero when the
     *            tie-break vertex lies inside the part
     * @param tot total weight
     * @return true if twice the weight is below {@code tot}, or equals {@code tot} while the
     *         tie-break vertex is outside the part
     */
    static boolean lessHalf(long[] sum, long tot) {
        return lessHalf(sum[0], sum[1], tot);
    }

    /** Allocation-free form of {@link #lessHalf(long[], long)}. */
    private static boolean lessHalf(long weight, long mark, long tot) {
        if (weight * 2 < tot) {
            return true;
        }
        return weight * 2 == tot && mark == 0;
    }

    /**
     * Finds the first position in [s, e] at which the prefix of the version difference
     * (p2 minus p1) stops being {@link #lessHalf(long[], long) less than half}.
     *
     * @param curSum weight / tie-break mark already accumulated to the left of {@code s}
     * @param t      position of the tie-break vertex
     */
    static int getMed(int s, int e, int p1, int p2, int k, int t, long[] curSum, long tot) {
        long weight = curSum[0];
        long mark = curSum[1];
        // Iterative descent; the recursive version cloned an array on every level.
        while (s < e) {
            int m = (s + e) >>> 1;
            long withLeft = weight + eval(lc[p2], k) - eval(lc[p1], k);
            long markLeft = mark + ((s <= t && t <= m) ? 1 : 0);
            if (lessHalf(withLeft, markLeft, tot)) {
                weight = withLeft;
                mark = markLeft;
                s = m + 1;
                p1 = rc[p1];
                p2 = rc[p2];
            } else {
                e = m;
                p1 = lc[p1];
                p2 = lc[p2];
            }
        }
        return s;
    }

    /** Weight of positions [s, e] in the version difference (p2 minus p1); node covers [ps, pe]. */
    private static long rangeWeight(int s, int e, int ps, int pe, int p1, int p2, long k) {
        if (e < ps || pe < s) {
            return 0L;
        }
        if (s <= ps && pe <= e) {
            return eval(p2, k) - eval(p1, k);
        }
        int pm = (ps + pe) >>> 1;
        return rangeWeight(s, e, ps, pm, lc[p1], lc[p2], k)
                + rangeWeight(s, e, pm + 1, pe, rc[p1], rc[p2], k);
    }

    /**
     * Returns {weight, mark} for positions [s, e] restricted to the node range [ps, pe], where
     * mark is 1 when position {@code t} lies in that intersection and 0 otherwise.
     */
    static long[] getSum(int s, int e, int ps, int pe, int p1, int p2, int k, int t) {
        long weight = rangeWeight(s, e, ps, pe, p1, p2, k);
        // The fully covered nodes partition the intersection, so the mark is a plain range test.
        boolean inside = Math.max(s, ps) <= t && t <= Math.min(e, pe);
        return new long[] {weight, inside ? 1 : 0};
    }

    /**
     * Answers one decoded query.
     *
     * @param l first vertex index of the weighted range (inclusive)
     * @param r last vertex index of the weighted range (inclusive)
     * @param k extra weight added to every vertex in the range
     * @param d tie-break vertex
     * @return the deepest vertex whose subtree is not less than half of the total weight
     */
    static int query(int l, int r, int k, int d) {
        long tot = asum[r + 1] - asum[l] + 1L * (r - l + 1) * k;
        int lo = root[l];
        int hi = root[r + 1];
        int t = din[d];

        // The answer is an ancestor (or self) of the vertex at the weighted median position.
        int v = rev[getMed(0, n - 1, lo, hi, k, t, new long[] {0, 0}, tot)];
        if (!lessHalf(getSum(din[v], dout[v] - 1, 0, n - 1, lo, hi, k, t), tot)) {
            return v;
        }
        // Climb to the highest ancestor that is still less than half; its parent is the answer.
        for (int i = levels - 1; i >= 0; i--) {
            // Was "1 >> i": only jump when an ancestor 2^i levels up really exists.
            if (dep[v] >= (1 << i)) {
                int anc = par[i][v];
                if (lessHalf(getSum(din[anc], dout[anc] - 1, 0, n - 1, lo, hi, k, t), tot)) {
                    v = anc;
                }
            }
        }
        return par[0][v];
    }

    /** Reads the tree, weights and queries from standard input and prints one answer per line. */
    public static void main(String[] args) {
        FastInput in = new FastInput(System.in);
        PrintWriter out = new PrintWriter(new BufferedWriter(new OutputStreamWriter(System.out)));

        n = in.nextInt();
        int edgeCount = n - 1;
        int[] eu = new int[edgeCount];
        int[] ev = new int[edgeCount];
        adjStart = new int[n + 1];
        for (int i = 0; i < edgeCount; i++) {
            eu[i] = in.nextInt();
            ev[i] = in.nextInt();
            adjStart[eu[i] + 1]++;
            adjStart[ev[i] + 1]++;
        }
        for (int i = 1; i <= n; i++) {
            adjStart[i] += adjStart[i - 1];
        }
        adj = new int[2 * edgeCount];
        int[] fill = Arrays.copyOf(adjStart, n);
        for (int i = 0; i < edgeCount; i++) {
            adj[fill[eu[i]]++] = ev[i];
            adj[fill[ev[i]]++] = eu[i];
        }

        levels = 32 - Integer.numberOfLeadingZeros(n);
        din = new int[n];
        dout = new int[n];
        dep = new int[n];
        rev = new int[n];
        par = new int[levels][n]; // par[0][0] stays 0: the root is its own parent
        piv = 0;
        dfs(0, -1);
        for (int i = 1; i < levels; i++) {
            for (int j = 0; j < n; j++) {
                par[i][j] = par[i - 1][par[i - 1][j]];
            }
        }

        int[] a = new int[n];
        for (int i = 0; i < n; i++) {
            a[i] = in.nextInt();
        }

        // Node budget: 2n - 1 for the empty tree plus (1 root + at most `height` path nodes)
        // per inserted vertex, plus the sentinel at index 0.
        int height = 32 - Integer.numberOfLeadingZeros(Math.max(1, n - 1));
        int capacity = 2 * n + n * (height + 1) + 1;
        lc = new int[capacity];
        rc = new int[capacity];
        cnt = new int[capacity];
        sum = new long[capacity];
        p = 0;

        root = new int[n + 1];
        asum = new long[n + 1];
        root[0] = newNode();
        init(0, n - 1, root[0]);
        for (int i = 1; i <= n; i++) {
            root[i] = newNode();
            update(din[i - 1], 0, n - 1, root[i - 1], root[i], a[i - 1]);
            asum[i] = asum[i - 1] + a[i - 1];
        }

        int q = in.nextInt();
        long S = 0; // running sum of all previous answers, used to decode the next query
        while (q-- > 0) {
            long aa = in.nextLong();
            long b = in.nextLong();
            long z = in.nextLong();
            long d = in.nextLong();
            aa = (aa + S) % n;
            b = (b + 2 * S) % n;
            // S can exceed 2^31.5, so S * S overflows long; reduce before squaring.
            // NOTE(review): decodes k as (k' + S^2) mod 150001; the previous expression added
            // k' twice, which looked unintended - confirm against the problem statement.
            long sm = S % MOD;
            z = (z + sm * sm) % MOD;
            if (aa > b) {
                long temp = aa;
                aa = b;
                b = temp;
            }
            int ans = query((int) aa, (int) b, (int) z, (int) d);
            out.println(ans);
            S += ans;
        }
        out.flush();
    }

    /** Minimal buffered reader for whitespace-separated decimal integers. */
    private static final class FastInput {
        private final InputStream in;
        private final byte[] buf = new byte[1 << 16];
        private int len;
        private int pos;

        FastInput(InputStream in) {
            this.in = in;
        }

        /** Returns the next byte, or -1 at end of input. */
        private int read() {
            if (pos == len) {
                pos = 0;
                try {
                    len = in.read(buf, 0, buf.length);
                } catch (IOException e) {
                    throw new UncheckedIOException(e);
                }
                if (len <= 0) {
                    len = 0;
                    return -1;
                }
            }
            return buf[pos++];
        }

        /** Skips anything that cannot start a number and parses the next signed integer. */
        long nextLong() {
            int c = read();
            while (c != -1 && c != '-' && (c < '0' || c > '9')) {
                c = read();
            }
            if (c == -1) {
                throw new NoSuchElementException("unexpected end of input");
            }
            boolean negative = false;
            if (c == '-') {
                negative = true;
                c = read();
            }
            long value = 0;
            while (c >= '0' && c <= '9') {
                value = value * 10 + (c - '0');
                c = read();
            }
            return negative ? -value : value;
        }

        int nextInt() {
            return (int) nextLong();
        }
    }
}