結果

問題 No.2258 The Jikka Tree
ユーザー qwewe
提出日時 2025-05-14 12:51:47
言語 PyPy3 (7.3.15)
結果
WA  
実行時間 -
コード長 6,617 bytes
コンパイル時間 493 ms
コンパイル使用メモリ 82,816 KB
実行使用メモリ 85,504 KB
最終ジャッジ日時 2025-05-14 12:53:09
合計ジャッジ時間 10,583 ms
ジャッジサーバーID
(参考情報)
judge1 / judge4
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 1
other AC * 2 WA * 1 TLE * 1 -- * 71
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
import bisect
from sys import stdin

_K_MOD = 150001  # modulus used when decoding k for every query after the first


def _rooted_preorder(n, adj, root=0):
    """Root the tree at ``root`` and number its vertices in DFS pre-order.

    Returns ``(parent, order, tin, tout)``:
      * ``parent[v]`` is v's parent (``-1`` for the root);
      * ``order[tin[v]] == v``;
      * the subtree of ``v`` is exactly the vertices at pre-order positions
        ``tin[v] .. tout[v] - 1`` (a contiguous block).

    Assumes ``adj`` describes a tree (connected, acyclic).
    """
    parent = [-1] * n
    tin = [0] * n
    order = []
    seen = [False] * n
    seen[root] = True
    stack = [root]
    while stack:
        v = stack.pop()
        tin[v] = len(order)
        order.append(v)
        for w in adj[v]:
            if not seen[w]:
                seen[w] = True
                parent[w] = v
                stack.append(w)
    # Every vertex appears after its parent in `order`, so a reverse sweep
    # accumulates subtree sizes bottom-up.
    size = [1] * n
    for v in reversed(order):
        p = parent[v]
        if p >= 0:
            size[p] += size[v]
    tout = [tin[v] + size[v] for v in range(n)]
    return parent, order, tin, tout


class _PersistentSumTree:
    """Persistent segment tree storing (point count, value sum) per position.

    Positions are ``0 .. size - 1``. Node 0 is a shared empty node whose
    children are itself, so walking below an absent subtree keeps reading
    zeros. Queries take two roots and work on their difference, which lets
    ``versions[r]`` minus ``versions[l]`` stand for an index range.
    """

    __slots__ = ("size", "left", "right", "cnt", "total")

    def __init__(self, size):
        self.size = size
        self.left = [0]
        self.right = [0]
        self.cnt = [0]
        self.total = [0]

    def _copy_plus(self, base, value):
        # Clone node `base` with one more point of weight `value` added.
        self.left.append(self.left[base])
        self.right.append(self.right[base])
        self.cnt.append(self.cnt[base] + 1)
        self.total.append(self.total[base] + value)
        return len(self.cnt) - 1

    def insert(self, root, pos, value):
        """Return a new root equal to ``root`` plus one point ``value`` at ``pos``."""
        left, right = self.left, self.right
        new_root = node = self._copy_plus(root, value)
        old = root
        lo, hi = 0, self.size
        while hi - lo > 1:
            mid = (lo + hi) // 2
            if pos < mid:
                old = left[old]
                child = self._copy_plus(old, value)
                left[node] = child
                hi = mid
            else:
                old = right[old]
                child = self._copy_plus(old, value)
                right[node] = child
                lo = mid
            node = child
        return new_root

    def prefix_weight(self, hi_root, lo_root, x, k):
        """Weight of positions ``< x`` in ``hi_root - lo_root``; each point adds value + k."""
        left, right, cnt, tot = self.left, self.right, self.cnt, self.total
        a, b = hi_root, lo_root
        lo, hi = 0, self.size
        acc = 0
        while True:
            if x >= hi:
                return acc + (tot[a] - tot[b]) + k * (cnt[a] - cnt[b])
            if x <= lo:
                return acc
            mid = (lo + hi) // 2
            if x <= mid:
                a, b = left[a], left[b]
                hi = mid
            else:
                la, lb = left[a], left[b]
                acc += (tot[la] - tot[lb]) + k * (cnt[la] - cnt[lb])
                a, b = right[a], right[b]
                lo = mid

    def range_weight(self, hi_root, lo_root, ql, qr, k):
        """Weight of positions in ``[ql, qr)`` in ``hi_root - lo_root``."""
        return (self.prefix_weight(hi_root, lo_root, qr, k)
                - self.prefix_weight(hi_root, lo_root, ql, k))

    def first_crossing(self, hi_root, lo_root, k, total):
        """Smallest position p with ``2 * prefix_weight(p + 1) > total``.

        Assumes ``total`` is the positive weight of the whole difference and
        that point weights are non-negative (so the prefix is monotone).
        """
        left, right, cnt, tot = self.left, self.right, self.cnt, self.total
        a, b = hi_root, lo_root
        lo, hi = 0, self.size
        acc = 0  # weight strictly left of the current node's range
        while hi - lo > 1:
            mid = (lo + hi) // 2
            la, lb = left[a], left[b]
            w = acc + (tot[la] - tot[lb]) + k * (cnt[la] - cnt[lb])
            if 2 * w > total:
                a, b = la, lb
                hi = mid
            else:
                acc = w
                a, b = right[a], right[b]
                lo = mid
        return lo


def _solve(tokens):
    """Parse the whole input from ``tokens`` and return the answers in query order."""
    read = iter(map(int, tokens)).__next__

    n = read()
    adj = [[] for _ in range(n)]
    for _ in range(n - 1):
        u = read()
        v = read()
        adj[u].append(v)
        adj[v].append(u)
    a_vals = [read() for _ in range(n)]

    root = 0
    parent, order, tin, tout = _rooted_preorder(n, adj, root)

    # versions[i] contains vertices with index < i, each stored at its
    # pre-order position, so versions[r] "minus" versions[l] is the index
    # range [l, r) laid out in tree order.
    tree = _PersistentSumTree(n)
    versions = [0] * (n + 1)
    for i in range(n):
        versions[i + 1] = tree.insert(versions[i], tin[i], a_vals[i])

    # Binary-lifting table; the root is its own ancestor, so jumps saturate there.
    log = max(1, n.bit_length())
    up = [[p if p >= 0 else root for p in parent]]
    for _ in range(log - 1):
        prev = up[-1]
        up.append([prev[prev[v]] for v in range(n)])

    answers = []
    answer_sum = 0
    for qi in range(read()):
        a_enc = read()
        b_enc = read()
        k_enc = read()
        # NOTE(review): delta is read but unused, exactly as in the earlier
        # version of this solution, which was judged WA. Confirm against the
        # problem statement whether delta selects among tied optima.
        read()

        if qi == 0:
            a, b, k = a_enc, b_enc, k_enc
        else:
            a = (a_enc + answer_sum) % n
            b = (b_enc + 2 * answer_sum) % n
            k = (k_enc + answer_sum * answer_sum) % _K_MOD

        lo_idx = min(a, b)
        hi_idx = min(max(a, b) + 1, n)
        hi_root = versions[hi_idx]
        lo_root = versions[lo_idx]
        total = tree.total[hi_root] - tree.total[lo_root] + k * (hi_idx - lo_idx)

        if total <= 0:
            # Nothing to balance: keep the root, as the greedy descent did for total == 0.
            x = root
        else:
            # The deepest vertex c with 2*W(c) > total has a contiguous pre-order
            # block holding over half the weight. That block must contain the
            # weighted median position, so c is an ancestor-or-self of the
            # median vertex. Climb from the median to c with binary lifting;
            # W is non-decreasing toward the root for non-negative weights.
            x = order[tree.first_crossing(hi_root, lo_root, k, total)]
            if 2 * tree.range_weight(hi_root, lo_root, tin[x], tout[x], k) <= total:
                for j in range(log - 1, -1, -1):
                    y = up[j][x]
                    if 2 * tree.range_weight(hi_root, lo_root, tin[y], tout[y], k) <= total:
                        x = y
                # x is now the highest ancestor still at most half; its parent is c.
                # The root always qualifies (2*total > total), so x != root here.
                x = parent[x]

        answers.append(x)
        answer_sum += x
    return answers


def main():
    """Read the whole problem from stdin and print one answer per query.

    Input tokens, in order: ``N``; ``N - 1`` edges ``u v`` (0-indexed);
    ``A_0 .. A_{N-1}``; ``Q``; then ``Q`` lines ``a' b' k' delta``.

    Query i is decoded with S = sum of all previous answers. The first
    query is used as-is; later ones use ``a = (a' + S) % N``,
    ``b = (b' + 2*S) % N`` and ``k = (k' + S*S) % 150001``. Vertices with
    index in ``[min(a, b), max(a, b)]`` get weight ``A_v + k``.

    The answer is the deepest vertex (tree rooted at 0) whose subtree holds
    strictly more than half of the total weight, or vertex 0 when the total
    is not positive. For non-negative weights this is a weighted centroid,
    presumably the minimiser the problem asks for (confirm against the
    statement).

    Runs in O(N log N) preprocessing and O(log^2 N) per query, instead of a
    level-by-level descent that re-queried every child.
    """
    answers = _solve(sys.stdin.read().split())
    if answers:
        sys.stdout.write("\n".join(map(str, answers)) + "\n")


if __name__ == '__main__':
    main()
0