結果

問題 No.2258 The Jikka Tree
ユーザー lam6er
提出日時 2025-03-20 18:57:55
言語 PyPy3 (7.3.15)
結果
WA  
実行時間 -
コード長 4,136 bytes
コンパイル時間 344 ms
コンパイル使用メモリ 82,348 KB
実行使用メモリ 97,536 KB
最終ジャッジ日時 2025-03-20 18:59:07
合計ジャッジ時間 17,976 ms
ジャッジサーバーID
(参考情報)
judge4 / judge1
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 1
other AC * 6 WA * 2 TLE * 1 -- * 66
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
from bisect import bisect_left, bisect_right

def _euler_tour(n, adj):
    """Run an iterative DFS from vertex 0 and return its Euler-tour data.

    Returns ``(tin, tout, order, parent)``: ``order`` lists the vertices in
    pre-order, ``order[tin[v]] == v``, the subtree of ``v`` occupies the
    positions ``[tin[v], tout[v])`` of ``order``, and ``parent[0] == -1``.
    """
    tin = [0] * n
    tout = [0] * n
    parent = [-1] * n
    order = []
    stack = [(0, False)]
    while stack:
        v, finished = stack.pop()
        if finished:
            # Every descendant of v has been entered by now.
            tout[v] = len(order)
            continue
        tin[v] = len(order)
        order.append(v)
        stack.append((v, True))
        # Push in reverse so children are entered in adjacency order.
        for u in reversed(adj[v]):
            if u != parent[v]:
                parent[u] = v
                stack.append((u, False))
    return tin, tout, order, parent


def _build_versions(n, tin, weights):
    """Build a persistent segment tree over the Euler positions ``[0, n)``.

    Version ``roots[i]`` stores, at position ``tin[v]``, a count of 1 and the
    value ``weights[v]`` for every vertex ``v < i``.  Subtracting version
    ``l`` from version ``r`` therefore restricts any position range to the
    vertices whose index lies in ``[l, r)``.  Node 0 is the shared empty node
    (its children are itself and its totals are 0).

    Returns ``(left, right, cnt, tot, roots)`` as parallel node arrays.
    """
    left = [0]
    right = [0]
    cnt = [0]
    tot = [0]
    roots = [0]
    for v in range(n):
        pos = tin[v]
        value = weights[v]
        # Walk down the previous version, remembering the path taken.
        path = []
        node, lo, hi = roots[-1], 0, n
        while hi - lo > 1:
            mid = (lo + hi) >> 1
            if pos < mid:
                path.append((node, True))
                node, hi = left[node], mid
            else:
                path.append((node, False))
                node, lo = right[node], mid
        # Copy the leaf, then every node on the path, bottom-up.
        left.append(0)
        right.append(0)
        cnt.append(cnt[node] + 1)
        tot.append(tot[node] + value)
        new = len(cnt) - 1
        for old, went_left in reversed(path):
            if went_left:
                left.append(new)
                right.append(right[old])
            else:
                left.append(left[old])
                right.append(new)
            cnt.append(cnt[old] + 1)
            tot.append(tot[old] + value)
            new = len(cnt) - 1
        roots.append(new)
    return left, right, cnt, tot, roots


def _ancestor_table(n, parent):
    """Return a binary-lifting table: ``up[j][v]`` is the ``2**j``-th ancestor
    of ``v``, clamped at the root 0 (enough levels to climb any depth < n)."""
    up = [[p if p >= 0 else 0 for p in parent]]
    for _ in range(1, max(1, (n - 1).bit_length())):
        prev = up[-1]
        up.append([prev[prev[v]] for v in range(n)])
    return up


def main():
    """Answer every query read from stdin, printing one vertex per line.

    Input (whitespace separated): ``N``; ``N - 1`` edges ``u v`` (0-indexed);
    ``A_0 .. A_{N-1}``; ``Q``; then ``Q`` groups ``a' b' k' delta``.

    Each query is decoded with the running sum of previous answers.  Then
    every vertex with index in ``[l, r)`` weighs ``A[v] + k`` and all other
    vertices weigh 0.  The answer is where the walk "start at root 0 and step
    into the neighbour holding strictly more than half of the total weight"
    stops.  That is the deepest vertex whose subtree weight exceeds half the
    total, or root 0 when the total is 0.

    Materialising every subtree would cost O(N * depth) memory, and walking
    the tree costs O(depth * degree) per query.  Instead each query costs
    O(log^2 N):

    * a persistent segment tree over Euler positions, versioned by vertex
      index, yields the ``[l, r)``-restricted weight of any subtree;
    * the Euler position where the prefix weight first exceeds half the
      total lies inside every "heavy" subtree.  So the answer is the deepest
      heavy ancestor-or-self of that vertex, found by binary lifting.

    NOTE(review): both this argument and the original greedy walk rely on
    every weight ``A[v] + k`` being non-negative — confirm ``A[v] >= 0``
    against the problem constraints.
    """
    values = iter(sys.stdin.read().split())
    n = int(next(values))
    adj = [[] for _ in range(n)]
    for _ in range(n - 1):
        u = int(next(values))
        v = int(next(values))
        adj[u].append(v)
        adj[v].append(u)
    weights = [int(next(values)) for _ in range(n)]
    q = int(next(values))
    queries = [tuple(int(next(values)) for _ in range(4)) for _ in range(q)]

    tin, tout, order, parent = _euler_tour(n, adj)
    left, right, cnt, tot, roots = _build_versions(n, tin, weights)
    up = _ancestor_table(n, parent)

    def prefix(rv, lv, x):
        """(count, sum of A) over Euler positions < x in version rv minus lv."""
        c = s = 0
        lo, hi = 0, n
        while lo < x:
            if x >= hi:
                # The whole node lies below x.
                c += cnt[rv] - cnt[lv]
                s += tot[rv] - tot[lv]
                break
            mid = (lo + hi) >> 1
            if x <= mid:
                rv, lv, hi = left[rv], left[lv], mid
            else:
                lc_r, lc_l = left[rv], left[lv]
                c += cnt[lc_r] - cnt[lc_l]
                s += tot[lc_r] - tot[lc_l]
                rv, lv, lo = right[rv], right[lv], mid
        return c, s

    def subtree_weight(v, rv, lv, k):
        """Total weight ``sum(A[u] + k)`` of the ``[l, r)`` vertices under v."""
        c_end, s_end = prefix(rv, lv, tout[v])
        c_start, s_start = prefix(rv, lv, tin[v])
        return (s_end - s_start) + k * (c_end - c_start)

    def median_position(rv, lv, k, total):
        """Smallest Euler position p with 2 * prefixweight(0..p) > total.

        Requires ``total > 0`` so that such a position exists.
        """
        acc = 0  # weight of positions strictly left of the current node
        lo, hi = 0, n
        while hi - lo > 1:
            mid = (lo + hi) >> 1
            lc_r, lc_l = left[rv], left[lv]
            w = (tot[lc_r] - tot[lc_l]) + k * (cnt[lc_r] - cnt[lc_l])
            if 2 * (acc + w) > total:
                rv, lv, hi = lc_r, lc_l, mid
            else:
                acc += w
                rv, lv, lo = right[rv], right[lv], mid
        return lo

    answers = []
    total_x = 0
    for a_prime, b_prime, k_prime, _delta in queries:
        # NOTE(review): delta is read but unused, and this decoding is kept
        # exactly as in the submission (judged WA) — verify it against the
        # problem statement.
        if not answers:
            a, b, k = a_prime, b_prime, k_prime
        else:
            a = (a_prime + total_x) % n
            b = (b_prime + 2 * total_x) % n
            k = (k_prime + total_x * total_x) % 150001
        l = min(a, b)
        r = max(a, b) + 1
        rv, lv = roots[r], roots[l]
        total = (tot[rv] - tot[lv]) + k * (cnt[rv] - cnt[lv])

        x = 0  # with zero total weight the walk never leaves the root
        if total > 0:
            u = order[median_position(rv, lv, k, total)]
            if 2 * subtree_weight(u, rv, lv, k) > total:
                x = u
            else:
                # Heavy vertices form a prefix of the root->u path, and the
                # root is always heavy.  Climb to the shallowest non-heavy
                # ancestor; its parent is the answer.
                for level in reversed(up):
                    w = level[u]
                    if 2 * subtree_weight(w, rv, lv, k) <= total:
                        u = w
                x = up[0][u]
        answers.append(x)
        total_x += x

    if answers:
        print("\n".join(map(str, answers)))


if __name__ == "__main__":
    main()
0