結果

問題 No.1442 I-wate Shortest Path Problem
ユーザー lam6er
提出日時 2025-04-15 22:35:31
言語 PyPy3
(7.3.15)
結果
WA  
実行時間 -
コード長 4,085 bytes
コンパイル時間 162 ms
コンパイル使用メモリ 81,944 KB
実行使用メモリ 194,556 KB
最終ジャッジ日時 2025-04-15 22:37:36
合計ジャッジ時間 24,095 ms
ジャッジサーバーID
(参考情報)
judge3 / judge2
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 2
other AC * 12 WA * 13
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
import heapq
from math import log2, floor
from collections import deque

sys.setrecursionlimit(1 << 25)

def main():
    input = sys.stdin.read().split()
    ptr = 0

    N, K = map(int, input[ptr:ptr+2])
    ptr += 2

    # Build adjacency list for the tree
    adj = [[] for _ in range(N+1)]  # 1-based indexing
    for _ in range(N-1):
        A = int(input[ptr])
        B = int(input[ptr+1])
        C = int(input[ptr+2])
        adj[A].append((B, C))
        adj[B].append((A, C))
        ptr += 3

    # Precompute parent, depth, and distance from root using BFS
    root = 1
    parent = [0] * (N + 1)
    depth = [0] * (N + 1)
    dist_from_root = [0] * (N + 1)
    visited = [False] * (N + 1)
    q = deque([root])
    visited[root] = True

    while q:
        u = q.popleft()
        for v, c in adj[u]:
            if not visited[v] and v != parent[u]:
                visited[v] = True
                parent[v] = u
                depth[v] = depth[u] + 1
                dist_from_root[v] = dist_from_root[u] + c
                q.append(v)

    # Precompute binary lifting table for LCA
    log_max = floor(log2(N)) if N else 0
    ancestors = [[0] * (N + 1) for _ in range(log_max + 1)]
    ancestors[0] = parent[:]
    for k in range(1, log_max + 1):
        for v in range(1, N + 1):
            ancestors[k][v] = ancestors[k-1][ancestors[k-1][v]]

    def get_lca(u, v):
        if depth[u] < depth[v]:
            u, v = v, u
        # Bring u to the depth of v
        for k in range(log_max, -1, -1):
            if depth[u] - (1 << k) >= depth[v]:
                u = ancestors[k][u]
        if u == v:
            return u
        for k in range(log_max, -1, -1):
            if ancestors[k][u] != ancestors[k][v]:
                u = ancestors[k][u]
                v = ancestors[k][v]
        return ancestors[0][u]

    # Read airlines and precompute their distances
    airlines = []
    for _ in range(K):
        M_i = int(input[ptr])
        P_i = int(input[ptr+1])
        ptr += 2
        X_i = list(map(int, input[ptr:ptr+M_i]))
        ptr += M_i

        # Compute shortest paths using Dijkstra's algorithm
        dist = [float('inf')] * (N + 1)
        heap = []
        for x in X_i:
            dist[x] = 0
            heapq.heappush(heap, (0, x))
        while heap:
            current_dist, u = heapq.heappop(heap)
            if current_dist > dist[u]:
                continue
            for v, c in adj[u]:
                new_dist = current_dist + c
                if new_dist < dist[v]:
                    dist[v] = new_dist
                    heapq.heappush(heap, (new_dist, v))
        airlines.append((P_i, dist))

    # Process queries
    Q = int(input[ptr])
    ptr += 1
    for _ in range(Q):
        U = int(input[ptr])
        V = int(input[ptr+1])
        ptr += 2

        # Calculate railway cost
        lca_node = get_lca(U, V)
        rail_cost = dist_from_root[U] + dist_from_root[V] - 2 * dist_from_root[lca_node]

        if K == 0:
            print(rail_cost)
            continue

        # Collect distances for each airline
        d_u = []
        d_v = []
        P_list = []
        for P_i, dist in airlines:
            d_u.append(dist[U])
            d_v.append(dist[V])
            P_list.append(P_i)

        # Calculate minimum cost using single airline
        min_single = float('inf')
        for i in range(len(P_list)):
            cost = P_list[i] + d_u[i] + d_v[i]
            if cost < min_single:
                min_single = cost

        # Calculate minimum cost using pairs of airlines
        min_pair = float('inf')
        for i in range(len(P_list)):
            for j in range(i + 1, len(P_list)):
                cost = P_list[i] + P_list[j] + min(d_u[i], d_u[j]) + min(d_v[i], d_v[j])
                if cost < min_pair:
                    min_pair = cost

        # Determine the minimum cost between all options
        airline_cost = min(min_single, min_pair)
        answer = min(rail_cost, airline_cost)
        print(answer)

if __name__ == '__main__':
    main()
0