結果

問題 No.2258 The Jikka Tree
ユーザー gew1fw
提出日時 2025-06-12 17:18:13
言語 PyPy3 (7.3.15)
結果
RE  
実行時間 -
コード長 3,431 bytes
コンパイル時間 391 ms
コンパイル使用メモリ 82,304 KB
実行使用メモリ 86,308 KB
最終ジャッジ日時 2025-06-12 17:18:26
合計ジャッジ時間 12,508 ms
ジャッジサーバーID (参考情報) judge4 / judge2
このコードへのチャレンジ (要ログイン)
ファイルパターン 結果
sample AC * 1
other AC * 1 WA * 1 RE * 1 TLE * 1 -- * 71
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
from sys import stdin
from math import log2, ceil
# Raise the interpreter recursion limit. Nothing visible in this file recurses
# (the tree DFS and the LCA lifting are both iterative), so this appears to be
# a leftover safeguard rather than something the solution relies on.
sys.setrecursionlimit(1 << 25)

# Modulus applied to k when decoding every query after the first.
_K_MOD = 150001
# The candidate search around delta stops expanding once this many vertices
# have been collected.
_CANDIDATE_LIMIT = 20


def _build_dist(n, adj):
    """Build a tree-distance oracle for the tree ``adj`` rooted at vertex 0.

    Args:
        n: number of vertices (vertex ids are ``0 .. n-1``).
        adj: adjacency lists; ``adj[u]`` lists the neighbours of ``u``.

    Returns:
        A function ``dist(u, v)`` giving the number of edges on the path
        between ``u`` and ``v``, computed via binary-lifting LCA.
    """
    # Enough lifting levels that 2**log exceeds the maximum depth (n - 1).
    # Keep at least one level: ceil(log2(1)) == 0 would leave the table empty
    # and the DFS below would crash on up[0] for a single-vertex tree.
    log = max(1, (n - 1).bit_length())
    up = [[-1] * n for _ in range(log)]
    depth = [0] * n

    # Iterative DFS from the root records each vertex's parent and depth.
    stack = [(0, -1)]
    while stack:
        u, p = stack.pop()
        up[0][u] = p
        for v in adj[u]:
            if v != p:
                depth[v] = depth[u] + 1
                stack.append((v, u))

    # up[k][v] is the 2**k-th ancestor of v, or -1 past the root.
    for k in range(1, log):
        prev, cur = up[k - 1], up[k]
        for v in range(n):
            mid = prev[v]
            # Guard against -1: prev[-1] would silently index from the end.
            cur[v] = prev[mid] if mid != -1 else -1

    def lca(u, v):
        if depth[u] < depth[v]:
            u, v = v, u
        # Lift the deeper vertex to the other vertex's depth.
        for k in range(log - 1, -1, -1):
            if depth[u] - (1 << k) >= depth[v]:
                u = up[k][u]
        if u == v:
            return u
        # Lift both while their ancestors differ; they stop just below the LCA.
        for k in range(log - 1, -1, -1):
            if up[k][u] != -1 and up[k][u] != up[k][v]:
                u = up[k][u]
                v = up[k][v]
        return up[0][u]

    def dist(u, v):
        return depth[u] + depth[v] - 2 * depth[lca(u, v)]

    return dist


def _nearby_vertices(adj, start, limit=_CANDIDATE_LIMIT):
    """Return vertices reachable from ``start`` in DFS discovery order.

    Expansion stops once at least ``limit`` vertices are collected. The check
    runs after a vertex's whole adjacency list is scanned, so the result may
    contain more than ``limit`` vertices.
    """
    found = [start]
    seen = {start}
    stack = [start]
    while stack:
        u = stack.pop()
        for v in adj[u]:
            if v not in seen:
                seen.add(v)
                found.append(v)
                stack.append(v)
        if len(found) >= limit:
            break
    return found


def main():
    """Read a tree and queries from stdin and print one vertex per query.

    Input tokens (whitespace separated):

    - N;
    - N-1 edges ``u v`` (0-indexed vertex ids);
    - N integers A_0 .. A_{N-1};
    - Q;
    - Q quadruples ``a' b' k' delta``.

    Query decoding:

    - The first query uses ``a', b', k'`` as read.
    - Every later query decodes them with S = sum of all previous answers:
      ``a = (a' + S) % N``, ``b = (b' + 2S) % N`` and
      ``k = (k' + S**2) % 150001``.

    Answer selection:

    - With ``l = min(a, b)`` and ``r = min(max(a, b) + 1, N)``, the answer is
      the vertex v minimising ``sum_{w in [l, r)} (A[w] + k) * dist(v, w)``.
    - If ``l >= r``, the range falls back to ``[0, 1)``.
    - v is chosen among the vertices collected by a bounded DFS from delta
      (see ``_nearby_vertices``). It is therefore the minimiser over that
      candidate set only, not necessarily over the whole tree.
    - On ties the earliest candidate wins.

    Each candidate costs O((r - l) log N) LCA work.
    """
    it = map(int, sys.stdin.read().split())
    n = next(it)

    adj = [[] for _ in range(n)]
    for _ in range(n - 1):
        u, v = next(it), next(it)
        adj[u].append(v)
        adj[v].append(u)
    weights = [next(it) for _ in range(n)]

    dist = _build_dist(n, adj)

    q = next(it)
    answers = []
    sum_x = 0  # running sum of previous answers, used to decode queries
    for idx in range(q):
        a_p, b_p, k_p, delta = next(it), next(it), next(it), next(it)
        if idx == 0:
            a, b, k = a_p, b_p, k_p
        else:
            a = (a_p + sum_x) % n
            b = (b_p + 2 * sum_x) % n
            k = (k_p + sum_x * sum_x) % _K_MOD
        # NOTE(review): unlike a, b and k, delta is used exactly as read
        # (never decoded from previous answers) — confirm against the
        # problem statement.

        lo = min(a, b)
        hi = min(max(a, b) + 1, n)
        if lo >= hi:
            lo, hi = 0, 1

        best_v, best_cost = -1, float('inf')
        for v in _nearby_vertices(adj, delta):
            cost = sum((weights[w] + k) * dist(v, w) for w in range(lo, hi))
            if cost < best_cost:  # strict: earliest candidate wins ties
                best_cost, best_v = cost, v

        answers.append(best_v)
        sum_x += best_v

    # One write; produces exactly one "x\n" line per answer (nothing if Q == 0).
    sys.stdout.write("".join(f"{x}\n" for x in answers))

if __name__ == "__main__":
    # Run the solver only when this file is executed as a script, not on import.
    main()
0