結果

問題 No.386 貪欲な領主
ユーザー lam6er
提出日時 2025-03-20 20:25:01
言語 PyPy3
(7.3.15)
結果
AC  
実行時間 478 ms / 2,000 ms
コード長 1,895 bytes
コンパイル時間 156 ms
コンパイル使用メモリ 82,464 KB
実行使用メモリ 109,736 KB
最終ジャッジ日時 2025-03-20 20:26:21
合計ジャッジ時間 3,897 ms
ジャッジサーバーID
(参考情報)
judge5 / judge4
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
other AC * 16
権限があれば一括ダウンロードができます

ソースコード

diff #

from collections import deque
import sys

def main():
    sys.setrecursionlimit(1 << 25)
    N = int(sys.stdin.readline())
    edges = [[] for _ in range(N)]
    for _ in range(N-1):
        a, b = map(int, sys.stdin.readline().split())
        edges[a].append(b)
        edges[b].append(a)
    
    U = []
    for _ in range(N):
        U.append(int(sys.stdin.readline()))
    
    # BFS to build parent, depth, S
    parent = [-1] * N
    depth = [0] * N
    S = [0] * N
    root = 0
    q = deque()
    q.append(root)
    parent[root] = -1
    S[root] = U[root]
    
    while q:
        u = q.popleft()
        for v in edges[u]:
            if parent[v] == -1 and v != root:
                parent[v] = u
                depth[v] = depth[u] + 1
                S[v] = S[u] + U[v]
                q.append(v)
    
    # Preprocess for LCA with binary lifting
    log_max = 20
    table = [[-1] * N for _ in range(log_max)]
    table[0] = parent
    
    for k in range(1, log_max):
        for v in range(N):
            if table[k-1][v] != -1:
                table[k][v] = table[k-1][table[k-1][v]]
    
    def get_lca(u, v):
        if depth[u] < depth[v]:
            u, v = v, u
        # Bring u to the depth of v
        for k in reversed(range(log_max)):
            if depth[u] - (1 << k) >= depth[v]:
                u = table[k][u]
        if u == v:
            return u
        # Move both upwards
        for k in reversed(range(log_max)):
            if table[k][u] != table[k][v]:
                u = table[k][u]
                v = table[k][v]
        return table[0][u]
    
    M = int(sys.stdin.readline())
    ans = 0
    for _ in range(M):
        A, B, C = map(int, sys.stdin.readline().split())
        lca_node = get_lca(A, B)
        sum_u = S[A] + S[B] - 2 * S[lca_node] + U[lca_node]
        ans += sum_u * C
    print(ans)

if __name__ == "__main__":
    main()
0