結果

問題 No.2258 The Jikka Tree
ユーザー gew1fw
提出日時 2025-06-12 17:11:34
言語 PyPy3 (7.3.15)
結果
WA  
実行時間 -
コード長 4,072 bytes
コンパイル時間 164 ms
コンパイル使用メモリ 82,796 KB
実行使用メモリ 94,432 KB
最終ジャッジ日時 2025-06-12 17:11:53
合計ジャッジ時間 14,257 ms
ジャッジサーバーID (参考情報) judge4 / judge2
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 1
other AC * 6 WA * 2 TLE * 1 -- * 66
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
from sys import stdin
sys.setrecursionlimit(1 << 25)

def _prefix_sums(values):
    """Return ``[0, v0, v0+v1, ...]`` so a slice sum is ``p[j] - p[i]``."""
    sums = [0]
    running = 0
    for value in values:
        running += value
        sums.append(running)
    return sums


def _root_tree(n, edges):
    """Root the tree at vertex 0.

    Returns ``(children, order)``. ``children[u]`` holds the neighbours of
    ``u`` other than its parent, in adjacency-list order. ``order`` is a
    traversal order in which every vertex appears after its parent.
    """
    parent = [-1] * n
    children = [[] for _ in range(n)]
    order = []
    stack = [0]
    while stack:
        node = stack.pop()
        order.append(node)
        for neighbor in edges[node]:
            if neighbor != parent[node]:
                parent[neighbor] = node
                children[node].append(neighbor)
                stack.append(neighbor)
    return children, order


def _index_subtrees(children, order, weights):
    """Build, for every vertex, its sorted subtree ids and their prefix sums.

    Returns ``(subtree, prefix)``. ``subtree[u]`` is the ascending list of
    vertex ids in the subtree of ``u`` (including ``u``), and
    ``prefix[u][i]`` is the sum of ``weights`` over the first ``i`` entries
    of ``subtree[u]``.

    Note: total size is the sum of all subtree sizes, which is quadratic in
    the number of vertices for a path-shaped tree.
    """
    n = len(children)
    subtree = [None] * n
    prefix = [None] * n
    # Reversed parent-before-child order guarantees children are ready first.
    for u in reversed(order):
        nodes = [u]
        for child in children[u]:
            nodes.extend(subtree[child])
        nodes.sort()
        subtree[u] = nodes
        prefix[u] = _prefix_sums(weights[w] for w in nodes)
    return subtree, prefix


def main():
    """Read a tree and queries from stdin and print one answer per query.

    Input (whitespace-separated integers): ``N``; ``N - 1`` edges ``u v``
    with 0-based endpoints; ``N`` vertex weights; ``Q``; then ``Q``
    quadruples ``a' b' k' delta``.

    The first query uses ``a', b', k'`` as given. Every later query is
    decoded with ``S``, the sum of all previous answers:
    ``a = (a' + S) % N``, ``b = (b' + 2S) % N``,
    ``k = (k' + S**2) % 150001``. The query range is
    ``[min(a, b), max(a, b) + 1)`` and each vertex ``w`` in it carries
    weight ``A[w] + k``.

    The answer is found by walking down from vertex 0 into the child whose
    subtree holds strictly more than half of the total in-range weight,
    stopping when no such child exists.
    """
    import sys
    from bisect import bisect_left

    tokens = sys.stdin.read().split()
    pos = 0

    n = int(tokens[pos])
    pos += 1
    edges = [[] for _ in range(n)]
    for _ in range(n - 1):
        u = int(tokens[pos])
        v = int(tokens[pos + 1])
        pos += 2
        edges[u].append(v)
        edges[v].append(u)
    weights = [int(tok) for tok in tokens[pos:pos + n]]
    pos += n
    query_count = int(tokens[pos])
    pos += 1
    queries = []
    for _ in range(query_count):
        queries.append(tuple(int(tok) for tok in tokens[pos:pos + 4]))
        pos += 4

    children, order = _root_tree(n, edges)
    subtree, subtree_prefix = _index_subtrees(children, order, weights)
    weight_prefix = _prefix_sums(weights)

    answers = []
    answer_total = 0  # running sum of all answers, used to decode queries
    for a_enc, b_enc, k_enc, delta in queries:
        if answers:
            a = (a_enc + answer_total) % n
            b = (b_enc + 2 * answer_total) % n
            k = (k_enc + answer_total ** 2) % 150001
        else:
            # The first query is taken verbatim, without any reduction.
            a, b, k = a_enc, b_enc, k_enc

        lo = min(a, b)
        hi = max(a, b) + 1

        # Total weight of the range: sum of A[w] plus k once per vertex.
        total = weight_prefix[hi] - weight_prefix[lo] + k * (hi - lo)

        if total == 0:
            # NOTE(review): delta is only ever used here, as the answer when
            # the range has zero total weight -- confirm against the problem
            # statement that this is the intended meaning of delta.
            answers.append(delta)
            answer_total += delta
            continue

        current = 0
        while True:
            best = 0
            heavy_child = -1
            for child in children[current]:
                nodes = subtree[child]
                first = bisect_left(nodes, lo)
                last = bisect_left(nodes, hi, first)
                sums = subtree_prefix[child]
                in_range = sums[last] - sums[first] + k * (last - first)
                if in_range > best:
                    best = in_range
                    heavy_child = child
            # NOTE(review): an exact half (2 * best == total) stops at the
            # vertex nearer the root; verify the required tie-breaking rule.
            if best * 2 > total:
                current = heavy_child
            else:
                break
        answers.append(current)
        answer_total += current

    if answers:
        print("\n".join(map(str, answers)))
# Run the solver only when this file is executed as a script.
if __name__ == "__main__":
    main()
0