結果

問題 No.1442 I-wate Shortest Path Problem
ユーザー lam6er
提出日時 2025-04-16 15:56:22
言語 PyPy3 (7.3.15)
結果
WA  
実行時間 -
コード長 4,106 bytes
コンパイル時間 189 ms
コンパイル使用メモリ 82,288 KB
実行使用メモリ 193,536 KB
最終ジャッジ日時 2025-04-16 15:59:28
合計ジャッジ時間 24,433 ms
ジャッジサーバーID (参考情報) judge1 / judge3
このコードへのチャレンジ (要ログイン)
ファイルパターン 結果
sample AC * 2
other AC * 18 WA * 7
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
import heapq
from math import log2, ceil
# Competitive-programming boilerplate: permit very deep recursion (2**25 frames).
# NOTE(review): the BFS, LCA and Dijkstra code in this file is iterative, so
# this limit appears to be unused.
sys.setrecursionlimit(2 ** 25)

def _tree_distance_oracle(n, adj, root=1):
    """Build a roads-only distance function for the tree.

    Args:
        n: number of cities; vertices are 1..n.
        adj: adjacency list; adj[u] holds (v, cost) pairs for every road at u.
        root: vertex the BFS starts from.

    Returns:
        A function ``dist(u, v)`` giving the total road cost of the unique
        tree path between u and v, answered with binary-lifting LCA.
    """
    from collections import deque

    # 2**log must exceed the largest possible depth (n - 1); keep >= 1 level.
    log = max(1, (n - 1).bit_length())
    up = [[0] * (n + 1) for _ in range(log)]
    depth = [0] * (n + 1)
    from_root = [0] * (n + 1)
    seen = [False] * (n + 1)
    seen[root] = True
    # The root is its own parent, so any jump past it simply stays on the
    # root; no -1 sentinel checks are needed while lifting.
    up[0][root] = root
    queue = deque([root])
    while queue:
        u = queue.popleft()
        for v, c in adj[u]:
            if not seen[v]:
                seen[v] = True
                up[0][v] = u
                depth[v] = depth[u] + 1
                from_root[v] = from_root[u] + c
                queue.append(v)
    for k in range(1, log):
        prev, cur = up[k - 1], up[k]
        for v in range(n + 1):
            cur[v] = prev[prev[v]]

    def lca(u, v):
        if depth[u] < depth[v]:
            u, v = v, u
        # Lift u by the depth difference, one set bit at a time.
        diff = depth[u] - depth[v]
        k = 0
        while diff:
            if diff & 1:
                u = up[k][u]
            diff >>= 1
            k += 1
        if u == v:
            return u
        for k in range(log - 1, -1, -1):
            if up[k][u] != up[k][v]:
                u = up[k][u]
                v = up[k][v]
        return up[0][u]

    def dist(u, v):
        return from_root[u] + from_root[v] - 2 * from_root[lca(u, v)]

    return dist


def _airline_distances(n, adj, airlines):
    """Return, per airline i, the list D_i of full shortest distances.

    D_i[v] is the cheapest cost, using roads AND any number of flights of
    any airline, between city v and the nearest airport of airline i.

    Flights are modelled with one hub node per airline (ids n+1 .. n+K):
    airport -> hub costs the fare P and hub -> airport is free, so hopping
    between two airports of the same airline costs exactly P.  Starting a
    Dijkstra at hub i with distance 0 reaches every airport of airline i at
    distance 0, i.e. it is a multi-source search from that airport set.
    Roads and flights are both usable in either direction, so the distance
    from the airport set to v equals the distance from v to it.

    Args:
        n: number of cities.
        adj: road adjacency list (not modified).
        airlines: list of (M, P, airports) tuples.

    Returns:
        A list of K lists indexed by vertex id; unreachable entries are inf.
    """
    k = len(airlines)
    size = n + k + 1
    graph = [list(edges) for edges in adj] + [[] for _ in range(k)]
    for idx, (_, fare, airports) in enumerate(airlines):
        hub = n + 1 + idx
        for x in airports:
            graph[x].append((hub, fare))
            graph[hub].append((x, 0))

    inf = float("inf")
    result = []
    for idx in range(k):
        hub = n + 1 + idx
        dist = [inf] * size
        dist[hub] = 0
        heap = [(0, hub)]
        while heap:
            d, u = heapq.heappop(heap)
            if d > dist[u]:
                continue  # stale heap entry
            for v, w in graph[u]:
                nd = d + w
                if nd < dist[v]:
                    dist[v] = nd
                    heapq.heappush(heap, (nd, v))
        result.append(dist)
    return result


def main():
    """Read one test case from stdin and print the answer to every query.

    Input (whitespace-separated integers): ``N K``; the ``N - 1`` roads
    ``A B C``; for each of the K airlines ``M P`` followed by its M airport
    cities; then ``Q`` and Q query pairs ``U V``.  Prints the minimum travel
    cost for each query, one per line.
    """
    nxt = map(int, sys.stdin.read().split()).__next__

    n, k = nxt(), nxt()

    adj = [[] for _ in range(n + 1)]
    for _ in range(n - 1):
        a, b, c = nxt(), nxt(), nxt()
        adj[a].append((b, c))
        adj[b].append((a, c))

    # Each airline is kept as (M, P, airports), mirroring the input layout.
    airlines = []
    for _ in range(k):
        m, p = nxt(), nxt()
        airlines.append((m, p, [nxt() for _ in range(m)]))

    tree_dist = _tree_distance_oracle(n, adj)
    # reach[i][v]: cheapest cost (roads + any flights) between v and the
    # closest airport of airline i.
    reach = _airline_distances(n, adj, airlines)
    fares = [p for _, p, _ in airlines]

    out = []
    for _ in range(nxt()):
        u, v = nxt(), nxt()
        # Roads only.
        best = tree_dist(u, v)
        # A route that flies at least once uses some airline i: it must reach
        # an airport of i, pay P_i, and continue from an airport of i.  So it
        # costs at least reach[i][u] + P_i + reach[i][v], and that bound is
        # itself achievable.  Chains through any number of other airlines are
        # already folded into reach[i], unlike a fixed one/two-airline check.
        for fare, r in zip(fares, reach):
            cost = r[u] + fare + r[v]
            if cost < best:
                best = cost
        out.append(best)
    sys.stdout.write("".join(f"{x}\n" for x in out))

if __name__ == "__main__":
    # Solve only when executed as a script; a plain import runs nothing.
    main()
0