結果

問題 No.2258 The Jikka Tree
ユーザー vjudge1
提出日時 2025-03-06 14:55:15
言語 Java
(openjdk 23)
結果
WA  
実行時間 -
コード長 5,740 bytes
コンパイル時間 4,435 ms
コンパイル使用メモリ 89,120 KB
実行使用メモリ 467,488 KB
最終ジャッジ日時 2025-03-06 14:56:43
合計ジャッジ時間 83,047 ms
ジャッジサーバーID
(参考情報)
judge2 / judge3
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample WA * 1
other AC * 3 WA * 27 TLE * 4 -- * 41
権限があれば一括ダウンロードができます

ソースコード

diff #

import java.io.IOException;
import java.io.InputStream;
import java.io.UncheckedIOException;
import java.util.*;

/**
 * Submission for yukicoder No.2258 "The Jikka Tree".
 *
 * <p>Input, as consumed by {@link #main}: {@code n}; {@code n - 1} edges {@code u v} with 0-based
 * endpoints; {@code n} vertex weights; {@code q}; then {@code q} lines {@code a b z d}.
 *
 * <p>Each query is decoded with the running sum {@code S} of previous answers into an index range
 * {@code [l, r]} and an offset {@code k}. Every vertex whose index lies in the range gets weight
 * {@code weight[index] + k}; all other vertices weigh 0. The answer is the deepest vertex (tree
 * rooted at 0) whose subtree holds more than half of the total weight, where a subtree holding
 * exactly half counts only if it contains vertex {@code d}.
 *
 * <p>Technique: a persistent segment tree over DFS-entry positions, with one version per vertex
 * index prefix, gives range-restricted subtree sums; the weighted median DFS position yields a
 * descendant of the answer, and binary lifting climbs to it.
 *
 * <p>NOTE(review): the problem statement is not part of this file. The decoding of {@code z} was
 * changed from the original {@code (2z + S^2) mod 150001} to {@code (z + S^2) mod 150001}, by
 * analogy with how {@code a} and {@code b} are decoded. Confirm against the statement.
 */
public class Main {
    /** Modulus applied to the decoded per-query weight offset. */
    static final long K_MOD = 150001;

    // Persistent segment tree stored as parallel primitive arrays.
    // Node 0 is the shared empty node: both children are 0, sum and count are 0.
    static int[] lc;
    static int[] rc;
    static int[] cnt;
    static long[] sum;
    /** Index of the most recently allocated segment-tree node. */
    static int p;

    static int n;
    /** Number of binary-lifting levels; {@code 2^levels > n}. */
    static int levels;
    /** DFS entry position of each vertex. */
    static int[] din;
    /** One past the last DFS position inside each vertex's subtree. */
    static int[] dout;
    /** Depth of each vertex; the root (vertex 0) has depth 0. */
    static int[] dep;
    /** Inverse of {@link #din}: vertex found at each DFS position. */
    static int[] rev;
    /** {@code par[i][v]} is the 2^i-th ancestor of v; it saturates at the root. */
    static int[][] par;
    /** Prefix sums of the vertex weights by vertex index. */
    static long[] asum;
    /** {@code root[i]} is the tree version containing vertices {@code 0 .. i-1}. */
    static int[] root;

    /** Minimal buffered reader for whitespace-separated integers. */
    private static final class FastInput {
        private final InputStream in;
        private final byte[] buf = new byte[1 << 16];
        private int len;
        private int pos;

        FastInput(InputStream in) {
            this.in = in;
        }

        private int read() {
            if (pos == len) {
                try {
                    len = in.read(buf, 0, buf.length);
                } catch (IOException e) {
                    throw new UncheckedIOException(e);
                }
                pos = 0;
                if (len <= 0) {
                    len = 0;
                    return -1;
                }
            }
            return buf[pos++];
        }

        long nextLong() {
            int c = read();
            while (c != -1 && c <= ' ') {
                c = read();
            }
            if (c == -1) {
                throw new NoSuchElementException("unexpected end of input");
            }
            boolean negative = false;
            if (c == '-') {
                negative = true;
                c = read();
            }
            if (c < '0' || c > '9') {
                throw new InputMismatchException("expected a digit but found byte " + c);
            }
            long value = 0;
            while (c >= '0' && c <= '9') {
                value = value * 10 + (c - '0');
                c = read();
            }
            return negative ? -value : value;
        }

        int nextInt() {
            return Math.toIntExact(nextLong());
        }
    }

    static int newNode() {
        return ++p;
    }

    /** Total weight stored in {@code node} when every counted vertex gets an extra {@code k}. */
    static long eval(int node, long k) {
        return sum[node] + k * cnt[node];
    }

    /**
     * Creates a new version that equals {@code prevRoot} plus one vertex of weight {@code v} at
     * DFS position {@code pos}. Only the root-to-leaf path is copied.
     *
     * @return the root of the new version
     */
    static int update(int pos, int prevRoot, int v) {
        int newRoot = newNode();
        int cur = newRoot;
        int old = prevRoot;
        int s = 0;
        int e = n - 1;
        while (true) {
            // Exactly one leaf below this node changes, so the aggregate can be set top-down.
            sum[cur] = sum[old] + v;
            cnt[cur] = cnt[old] + 1;
            if (s == e) {
                return newRoot;
            }
            int m = (s + e) >>> 1;
            if (pos <= m) {
                rc[cur] = rc[old];
                lc[cur] = newNode();
                cur = lc[cur];
                old = lc[old];
                e = m;
            } else {
                lc[cur] = lc[old];
                rc[cur] = newNode();
                cur = rc[cur];
                old = rc[old];
                s = m + 1;
            }
        }
    }

    /**
     * Tells whether a part of the weight is "less than half" of {@code tot}. A part that is
     * exactly half still counts as less when it does not contain the tie-break vertex.
     */
    static boolean lessHalf(long part, boolean containsD, long tot) {
        long twice = part * 2;
        return twice < tot || (twice == tot && !containsD);
    }

    /**
     * Finds the first DFS position whose prefix {@code [0 .. pos]} is not
     * {@link #lessHalf}, using the weight difference between versions {@code p2} and {@code p1}.
     *
     * @param t DFS position of the tie-break vertex
     */
    static int getMed(int p1, int p2, long k, int t, long tot) {
        int s = 0;
        int e = n - 1;
        long before = 0; // weight of all positions strictly left of [s, e]
        while (s < e) {
            int m = (s + e) >>> 1;
            long upToMid = before + eval(lc[p2], k) - eval(lc[p1], k);
            // The prefix [0 .. m] contains the tie-break position exactly when t <= m.
            if (lessHalf(upToMid, t <= m, tot)) {
                before = upToMid;
                s = m + 1;
                p1 = rc[p1];
                p2 = rc[p2];
            } else {
                e = m;
                p1 = lc[p1];
                p2 = lc[p2];
            }
        }
        return s;
    }

    /**
     * Weight in DFS positions {@code [lo, hi]} taken as the difference between versions
     * {@code p2} and {@code p1}; {@code [ps, pe]} is the range covered by the current nodes.
     */
    static long getSum(int lo, int hi, int ps, int pe, int p1, int p2, long k) {
        if (p1 == p2 || hi < ps || pe < lo) {
            return 0; // identical nodes contribute no difference
        }
        if (lo <= ps && pe <= hi) {
            return eval(p2, k) - eval(p1, k);
        }
        int pm = (ps + pe) >>> 1;
        return getSum(lo, hi, ps, pm, lc[p1], lc[p2], k)
                + getSum(lo, hi, pm + 1, pe, rc[p1], rc[p2], k);
    }

    /** Whether the subtree of {@code v} is {@link #lessHalf} for the given query. */
    private static boolean subtreeLessHalf(int v, int p1, int p2, long k, int t, long tot) {
        long part = getSum(din[v], dout[v] - 1, 0, n - 1, p1, p2, k);
        boolean containsD = din[v] <= t && t < dout[v];
        return lessHalf(part, containsD, tot);
    }

    /**
     * Answers one decoded query.
     *
     * @param l first vertex index of the weighted range (inclusive)
     * @param r last vertex index of the weighted range (inclusive)
     * @param k offset added to the weight of every vertex in the range
     * @param d tie-break vertex; must be a valid vertex index
     * @return the deepest vertex whose subtree is not {@link #lessHalf}
     */
    static int query(int l, int r, long k, int d) {
        long tot = asum[r + 1] - asum[l] + (long) (r - l + 1) * k;
        int p1 = root[l];
        int p2 = root[r + 1];
        int t = din[d];
        int v = rev[getMed(p1, p2, k, t, tot)];
        if (!subtreeLessHalf(v, p1, p2, k, t, tot)) {
            return v;
        }
        // Climb to the highest ancestor whose subtree is still less than half.
        for (int i = levels - 1; i >= 0; i--) {
            if (dep[v] >= (1 << i)) {
                int anc = par[i][v];
                if (subtreeLessHalf(anc, p1, p2, k, t, tot)) {
                    v = anc;
                }
            }
        }
        // The root's subtree is never less than half, so v is not the root here.
        return par[0][v];
    }

    /** Computes din/dout/dep/rev/par from an adjacency list in CSR form, without recursion. */
    private static void buildTree(int[] start, int[] adj) {
        din = new int[n];
        dout = new int[n];
        dep = new int[n];
        rev = new int[n];
        levels = Math.max(1, 32 - Integer.numberOfLeadingZeros(n));
        par = new int[levels][n];

        int[] cursor = Arrays.copyOf(start, n);
        int[] stack = new int[n];
        int sp = 0;
        int piv = 0;
        din[0] = piv;
        rev[piv++] = 0;
        stack[sp++] = 0;
        while (sp > 0) {
            int x = stack[sp - 1];
            if (cursor[x] < start[x + 1]) {
                int y = adj[cursor[x]++];
                int parent = (x == 0) ? -1 : par[0][x];
                if (y != parent) {
                    par[0][y] = x;
                    dep[y] = dep[x] + 1;
                    din[y] = piv;
                    rev[piv++] = y;
                    stack[sp++] = y;
                }
            } else {
                dout[x] = piv;
                sp--;
            }
        }
        for (int i = 1; i < levels; i++) {
            for (int j = 0; j < n; j++) {
                par[i][j] = par[i - 1][par[i - 1][j]];
            }
        }
    }

    /**
     * Reads the whole input from {@code System.in} and prints one answer per query to
     * {@code System.out}. All static state is rebuilt on every call.
     *
     * @throws IllegalArgumentException if the vertex count is not positive
     * @throws UncheckedIOException if reading standard input fails
     */
    public static void main(String[] args) {
        FastInput in = new FastInput(System.in);
        n = in.nextInt();
        if (n <= 0) {
            throw new IllegalArgumentException("vertex count must be positive: " + n);
        }

        // Adjacency in CSR form; neighbours keep the order in which edges were read.
        int edges = n - 1;
        int[] eu = new int[edges];
        int[] ev = new int[edges];
        int[] start = new int[n + 1];
        for (int i = 0; i < edges; i++) {
            eu[i] = in.nextInt();
            ev[i] = in.nextInt();
            start[eu[i] + 1]++;
            start[ev[i] + 1]++;
        }
        for (int i = 0; i < n; i++) {
            start[i + 1] += start[i];
        }
        int[] fill = Arrays.copyOf(start, n);
        int[] adj = new int[2 * edges];
        for (int i = 0; i < edges; i++) {
            adj[fill[eu[i]]++] = ev[i];
            adj[fill[ev[i]]++] = eu[i];
        }
        buildTree(start, adj);

        // Each insertion allocates one node per level of the root-to-leaf path.
        int height = (n == 1) ? 0 : 32 - Integer.numberOfLeadingZeros(n - 1);
        int capacity = n * (height + 1) + 1;
        lc = new int[capacity];
        rc = new int[capacity];
        cnt = new int[capacity];
        sum = new long[capacity];
        p = 0;

        asum = new long[n + 1];
        root = new int[n + 1]; // root[0] == 0, the shared empty node
        for (int i = 1; i <= n; i++) {
            int w = in.nextInt();
            root[i] = update(din[i - 1], root[i - 1], w);
            asum[i] = asum[i - 1] + w;
        }

        int q = in.nextInt();
        StringBuilder out = new StringBuilder();
        long S = 0;
        while (q-- > 0) {
            long aa = in.nextLong();
            long b = in.nextLong();
            long z = in.nextLong();
            long d = in.nextLong();
            aa = (aa + S) % n;
            b = (b + S * 2) % n;
            // Reduce S before squaring: S can reach about 2e10 and S * S would overflow a long.
            long sMod = S % K_MOD;
            z = (z + sMod * sMod) % K_MOD;
            if (aa > b) {
                long tmp = aa;
                aa = b;
                b = tmp;
            }
            int ans = query((int) aa, (int) b, z, (int) d);
            out.append(ans).append('\n');
            S += ans;
        }
        System.out.print(out);
        System.out.flush();
    }
}
0