結果

問題 No.386 貪欲な領主
ユーザー lam6er
提出日時 2025-03-20 20:27:28
言語 PyPy3
(7.3.15)
結果
AC  
実行時間 547 ms / 2,000 ms
コード長 2,077 bytes
コンパイル時間 145 ms
コンパイル使用メモリ 82,316 KB
実行使用メモリ 177,620 KB
最終ジャッジ日時 2025-03-20 20:28:58
合計ジャッジ時間 4,611 ms
ジャッジサーバーID
(参考情報)
judge4 / judge1
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
other AC * 16
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
from collections import deque

def main():
    input = sys.stdin.read().split()
    ptr = 0

    N = int(input[ptr])
    ptr += 1

    adj = [[] for _ in range(N)]
    for _ in range(N-1):
        a = int(input[ptr])
        b = int(input[ptr+1])
        adj[a].append(b)
        adj[b].append(a)
        ptr += 2

    U = []
    for _ in range(N):
        U.append(int(input[ptr]))
        ptr += 1

    M = int(input[ptr])
    ptr += 1

    queries = []
    for _ in range(M):
        a = int(input[ptr])
        b = int(input[ptr+1])
        c = int(input[ptr+2])
        queries.append((a, b, c))
        ptr += 3

    # Initialize BFS for parent, depth, sum_u
    parent = [-1] * N
    depth = [0] * N
    sum_u = [0] * N
    root = 0
    sum_u[root] = U[root]
    q = deque([root])
    parent[root] = -1

    while q:
        u = q.popleft()
        for v in adj[u]:
            if parent[u] != v and parent[v] == -1:
                parent[v] = u
                depth[v] = depth[u] + 1
                sum_u[v] = sum_u[u] + U[v]
                q.append(v)

    # Prepare binary lifting table for LCA
    max_level = 20
    up = [[-1] * N for _ in range(max_level)]
    up[0] = parent[:]
    for k in range(1, max_level):
        for u in range(N):
            if up[k-1][u] != -1:
                up[k][u] = up[k-1][up[k-1][u]]
            else:
                up[k][u] = -1

    def get_lca(a, b):
        if depth[a] < depth[b]:
            a, b = b, a
        # Bring a to the depth of b
        for k in range(max_level-1, -1, -1):
            if depth[a] - (1 << k) >= depth[b]:
                a = up[k][a]
        if a == b:
            return a
        # Move both up to find LCA
        for k in range(max_level-1, -1, -1):
            if up[k][a] != up[k][b]:
                a = up[k][a]
                b = up[k][b]
        return parent[a]

    total = 0
    for a, b, c in queries:
        l = get_lca(a, b)
        su = sum_u[a] + sum_u[b] - 2 * sum_u[l] + U[l]
        total += su * c
    print(total)

if __name__ == '__main__':
    main()
0