結果

問題 No.1442 I-wate Shortest Path Problem
ユーザー gew1fw
提出日時 2025-06-12 19:52:22
言語 PyPy3
(7.3.15)
結果
WA  
実行時間 -
コード長 4,146 bytes
コンパイル時間 177 ms
コンパイル使用メモリ 81,860 KB
実行使用メモリ 213,016 KB
最終ジャッジ日時 2025-06-12 19:53:34
合計ジャッジ時間 30,144 ms
ジャッジサーバーID
(参考情報)
judge2 / judge4
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 2
other AC * 12 WA * 13
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
import heapq
from collections import deque

def main():
    sys.setrecursionlimit(1 << 25)
    input = sys.stdin.read().split()
    ptr = 0
    N, K = int(input[ptr]), int(input[ptr+1])
    ptr += 2

    # Read railway edges
    adj = [[] for _ in range(N+1)]
    for _ in range(N-1):
        A = int(input[ptr])
        B = int(input[ptr+1])
        C = int(input[ptr+2])
        adj[A].append((B, C))
        adj[B].append((A, C))
        ptr += 3

    # Compute LCA, depth, and distance from root
    root = 1
    LOG = 20
    up = [[-1]*(N+1) for _ in range(LOG)]
    depth = [0]*(N+1)
    distance = [0]*(N+1)  # distance from root

    # BFS to compute parent, depth, distance
    visited = [False]*(N+1)
    q = deque()
    q.append(root)
    visited[root] = True
    up[0][root] = -1

    while q:
        u = q.popleft()
        for v, c in adj[u]:
            if not visited[v]:
                visited[v] = True
                up[0][v] = u
                depth[v] = depth[u] + 1
                distance[v] = distance[u] + c
                q.append(v)

    # Precompute binary lifting table
    for k in range(1, LOG):
        for v in range(1, N+1):
            if up[k-1][v] != -1:
                up[k][v] = up[k-1][up[k-1][v]]
            else:
                up[k][v] = -1

    def lca(u, v):
        if depth[u] < depth[v]:
            u, v = v, u
        # Bring u up to depth of v
        for k in range(LOG-1, -1, -1):
            if depth[u] - (1 << k) >= depth[v]:
                u = up[k][u]
        if u == v:
            return u
        for k in range(LOG-1, -1, -1):
            if up[k][u] != -1 and up[k][u] != up[k][v]:
                u = up[k][u]
                v = up[k][v]
        return up[0][u]

    # Compute railway distance between u and v
    def railway_cost(u, v):
        ancestor = lca(u, v)
        return distance[u] + distance[v] - 2 * distance[ancestor]

    # Read airlines
    airlines = []
    sets = []
    for _ in range(K):
        M_i = int(input[ptr])
        P_i = int(input[ptr+1])
        ptr += 2
        X_list = list(map(int, input[ptr:ptr+M_i]))
        ptr += M_i
        airlines.append( (M_i, P_i, X_list) )
        sets.append( set(X_list) )

    # Precompute multi-source Dijkstra for each airline
    da = [ [float('inf')] * (N+1) for _ in range(K) ]
    for a in range(K):
        M_i, P_i, X_list = airlines[a]
        heap = []
        for x in X_list:
            da[a][x] = 0
            heapq.heappush(heap, (0, x))
        # Dijkstra
        while heap:
            d, u = heapq.heappop(heap)
            if d > da[a][u]:
                continue
            for v, c in adj[u]:
                if da[a][v] > d + c:
                    da[a][v] = d + c
                    heapq.heappush(heap, (da[a][v], v))

    # Precompute intersecting airlines
    intersect = [ [False]*K for _ in range(K) ]
    for a in range(K):
        for b in range(a+1, K):
            if len( sets[a] & sets[b] ) > 0:
                intersect[a][b] = True

    # Process queries
    Q = int(input[ptr])
    ptr +=1
    output = []
    for _ in range(Q):
        u = int(input[ptr])
        v = int(input[ptr+1])
        ptr +=2

        # Railway cost
        rail_cost = railway_cost(u, v)
        min_cost = rail_cost

        # Check each single airline
        for a in range(K):
            cost = da[a][u] + airlines[a][1] + da[a][v]
            if cost < min_cost:
                min_cost = cost

        # Check pairs of airlines that intersect
        for a in range(K):
            for b in range(a+1, K):
                if intersect[a][b]:
                    # Compute cost using a then b
                    cost1 = da[a][u] + airlines[a][1] + airlines[b][1] + da[b][v]
                    # Compute cost using b then a
                    cost2 = da[b][u] + airlines[b][1] + airlines[a][1] + da[a][v]
                    current_min = min(cost1, cost2)
                    if current_min < min_cost:
                        min_cost = current_min

        output.append(str(min_cost))
    
    print('\n'.join(output))

if __name__ == '__main__':
    main()
0