結果

問題 No.2258 The Jikka Tree
ユーザー qwewe
提出日時 2025-04-24 12:28:00
言語 PyPy3
(7.3.15)
結果
WA  
実行時間 -
コード長 6,485 bytes
コンパイル時間 329 ms
コンパイル使用メモリ 82,280 KB
実行使用メモリ 97,372 KB
最終ジャッジ日時 2025-04-24 12:30:15
合計ジャッジ時間 10,085 ms
ジャッジサーバーID
(参考情報)
judge5 / judge2
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 1
other AC * 2 WA * 1 TLE * 1 -- * 71
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
import bisect

sys.setrecursionlimit(1 << 25)

def main():
    input = sys.stdin.read().split()
    ptr = 0

    N = int(input[ptr])
    ptr +=1

    edges = [[] for _ in range(N)]
    for _ in range(N-1):
        u = int(input[ptr])
        v = int(input[ptr+1])
        ptr +=2
        edges[u].append(v)
        edges[v].append(u)

    # Build tree with parent and children
    root = 0
    parent = [ -1 ] * N
    children = [[] for _ in range(N)]
    stack = [root]
    while stack:
        u = stack.pop()
        for v in edges[u]:
            if parent[u] != v:
                parent[v] = u
                children[u].append(v)
                stack.append(v)

    # Compute in_time, out_time via DFS
    in_time = [0] * N
    out_time = [0] * N
    depth = [0] * N
    time = 0
    stack = [(root, False)]
    while stack:
        u, visited = stack.pop()
        if visited:
            out_time[u] = time
            time +=1
            continue
        in_time[u] = time
        time +=1
        stack.append( (u, True) )
        # Reverse to process children in order
        for v in reversed(children[u]):
            stack.append( (v, False) )
            depth[v] = depth[u] +1

    # Read A
    A = list(map(int, input[ptr:ptr+N]))
    ptr +=N

    # Build segment tree
    class SegmentNode:
        __slots__ = ['left', 'right', 'l', 'r', 'in_times', 'prefix_sum']
        def __init__(self, l, r):
            self.l = l
            self.r = r
            self.left = None
            self.right = None
            self.in_times = []
            self.prefix_sum = []

    def build(l, r):
        node = SegmentNode(l, r)
        if l == r:
            # Leaf node: contains in_time[l], A[l]
            node.in_times = [ in_time[l] ]
            node.prefix_sum = [ A[l] ]
        else:
            mid = (l + r) // 2
            node.left = build(l, mid)
            node.right = build(mid+1, r)
            # Merge sorted lists
            i = j = 0
            node.in_times = []
            node.prefix_sum = []
            sum_ = 0
            while i < len(node.left.in_times) and j < len(node.right.in_times):
                if node.left.in_times[i] < node.right.in_times[j]:
                    node.in_times.append(node.left.in_times[i])
                    sum_ += node.left.prefix_sum[i] - (node.left.prefix_sum[i-1] if i>0 else 0)
                    node.prefix_sum.append(sum_)
                    i +=1
                else:
                    node.in_times.append(node.right.in_times[j])
                    sum_ += node.right.prefix_sum[j] - (node.right.prefix_sum[j-1] if j>0 else 0)
                    node.prefix_sum.append(sum_)
                    j +=1
            while i < len(node.left.in_times):
                node.in_times.append(node.left.in_times[i])
                sum_ += node.left.prefix_sum[i] - (node.left.prefix_sum[i-1] if i>0 else 0)
                node.prefix_sum.append(sum_)
                i +=1
            while j < len(node.right.in_times):
                node.in_times.append(node.right.in_times[j])
                sum_ += node.right.prefix_sum[j] - (node.right.prefix_sum[j-1] if j>0 else 0)
                node.prefix_sum.append(sum_)
                j +=1
        return node

    # Build segment tree for original indices 0..N-1
    seg_root = build(0, N-1)

    # Function to query count and sum in [l, r) and in_time [a, b]
    def query_segment(node, ql, qr, a, b):
        if node.r < ql or node.l > qr:
            return (0, 0)
        if ql <= node.l and node.r <= qr:
            # Binary search in in_times
            left = bisect.bisect_left(node.in_times, a)
            right_idx = bisect.bisect_right(node.in_times, b)
            cnt = right_idx - left
            if cnt ==0:
                sum_a = 0
            else:
                sum_a = node.prefix_sum[right_idx-1]
                if left >0:
                    sum_a -= node.prefix_sum[left-1]
            return (cnt, sum_a)
        else:
            cntL, sumL = query_segment(node.left, ql, qr, a, b)
            cntR, sumR = query_segment(node.right, ql, qr, a, b)
            return (cntL + cntR, sumL + sumR)

    # Precompute prefix sums for A and depth
    prefix_A = [0]*(N+1)
    prefix_depth = [0]*(N+1)
    for i in range(N):
        prefix_A[i+1] = prefix_A[i] + A[i]
        prefix_depth[i+1] = prefix_depth[i] + depth[i]

    Q = int(input[ptr])
    ptr +=1

    queries = []
    for _ in range(Q):
        a_prime = int(input[ptr])
        b_prime = int(input[ptr+1])
        k_prime = int(input[ptr+2])
        delta = int(input[ptr+3])
        ptr +=4
        queries.append( (a_prime, b_prime, k_prime, delta) )

    X = []
    sum_X = 0
    for i in range(Q):
        a_prime, b_prime, k_prime, delta = queries[i]
        if i ==0:
            a = a_prime
            b = b_prime
            k = k_prime
        else:
            a = (a_prime + sum_X) % N
            b = (b_prime + 2 * sum_X) % N
            k = (k_prime + sum_X * sum_X) % 150001
        l = min(a, b)
        r = 1 + max(a, b)
        # Compute l_i and r_i
        # Now process query l, r, k, delta
        # Compute sum_total = sum_A + k * (r - l)
        sum_A = prefix_A[r] - prefix_A[l]
        count_nodes = r - l
        sum_total = sum_A + k * count_nodes

        # Find the 1-median
        current = root
        while True:
            max_child = None
            max_sum = -1
            total_sum = sum_total
            for v in children[current]:
                # Compute sum for subtree of v
                inL = in_time[v]
                inR = out_time[v] -1  # since out_time is exclusive
                cnt, sum_a = query_segment(seg_root, l, r-1, inL, inR)
                sum_v = sum_a + k * cnt
                if sum_v > max_sum:
                    max_sum = sum_v
                    max_child = v
            if max_sum > total_sum / 2:
                current = max_child
            else:
                break

        # Now, current is the 1-median candidate
        # Check tiebreaker with delta_i
        # But according to problem statement, the answer is unique, so no need
        # But just in case, find the closest to delta_i among candidates with same sum
        # However, given time constraints, we'll proceed with current as the answer
        X_i = current
        X.append(X_i)
        sum_X += X_i

    for x in X:
        print(x)

if __name__ == '__main__':
    main()
0