結果

問題 No.1442 I-wate Shortest Path Problem
ユーザー gew1fw
提出日時 2025-06-12 14:09:13
言語 PyPy3
(7.3.15)
結果
WA  
実行時間 -
コード長 4,015 bytes
コンパイル時間 325 ms
コンパイル使用メモリ 82,604 KB
実行使用メモリ 168,104 KB
最終ジャッジ日時 2025-06-12 14:09:40
合計ジャッジ時間 19,450 ms
ジャッジサーバーID
(参考情報)
judge5 / judge1
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 2
other AC * 7 WA * 8 TLE * 1 -- * 9
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
import heapq
from math import log2, ceil

def main():
    sys.setrecursionlimit(1 << 25)
    input = sys.stdin.read().split()
    ptr = 0

    N, K = int(input[ptr]), int(input[ptr+1])
    ptr +=2

    # Build adjacency list
    adj = [[] for _ in range(N+1)]
    for _ in range(N-1):
        a = int(input[ptr])
        b = int(input[ptr+1])
        c = int(input[ptr+2])
        adj[a].append((b, c))
        adj[b].append((a, c))
        ptr +=3

    # Precompute LCA and distance from root
    max_log = ceil(log2(N)) if N > 1 else 1
    parent = [[-1]*(N+1) for _ in range(max_log)]
    depth = [0]*(N+1)
    dist_root = [0]*(N+1)

    # Iterative DFS to fill parent[0], depth, dist_root
    stack = [(1, -1, 0, 0)]  # (node, parent, depth, distance)
    while stack:
        u, p, d, dist = stack.pop()
        parent[0][u] = p
        depth[u] = d
        dist_root[u] = dist
        for v, cost in adj[u]:
            if v != p:
                stack.append((v, u, d+1, dist + cost))

    # Fill parent table for binary lifting
    for k in range(1, max_log):
        for u in range(1, N+1):
            if parent[k-1][u] != -1:
                parent[k][u] = parent[k-1][parent[k-1][u]]
            else:
                parent[k][u] = -1

    # LCA function
    def lca(u, v):
        if depth[u] < depth[v]:
            u, v = v, u
        # Bring u to the depth of v
        for k in reversed(range(max_log)):
            if depth[u] - (1 << k) >= depth[v]:
                u = parent[k][u]
        if u == v:
            return u
        for k in reversed(range(max_log)):
            if parent[k][u] != -1 and parent[k][u] != parent[k][v]:
                u = parent[k][u]
                v = parent[k][v]
        return parent[0][u]

    # Function to compute distance between u and v
    def get_distance(u, v):
        ancestor = lca(u, v)
        return dist_root[u] + dist_root[v] - 2 * dist_root[ancestor]

    # Read airlines
    airlines = []
    P = []
    for _ in range(K):
        M_i = int(input[ptr])
        P_i = int(input[ptr+1])
        ptr +=2
        X = list(map(int, input[ptr:ptr+M_i]))
        ptr += M_i
        airlines.append(X)
        P.append(P_i)

    # For each airline, compute dist_airline[i][x]
    dist_airline = []
    for i in range(K):
        sources = airlines[i]
        dist = [float('inf')] * (N+1)
        heap = []
        for s in sources:
            dist[s] = 0
            heapq.heappush(heap, (0, s))
        while heap:
            d, u = heapq.heappop(heap)
            if d > dist[u]:
                continue
            for v, cost in adj[u]:
                if dist[v] > d + cost:
                    dist[v] = d + cost
                    heapq.heappush(heap, (dist[v], v))
        dist_airline.append(dist)

    # Precompute sum_P for all masks
    sum_P = [0]*(1 << K)
    for mask in range(1, 1 << K):
        lb = mask & -mask
        i = (lb).bit_length() -1
        sum_P[mask] = sum_P[mask ^ lb] + P[i]

    # Precompute airlines_in_mask
    airlines_in_mask = [[] for _ in range(1 << K)]
    for mask in range(1 << K):
        for i in range(K):
            if mask & (1 << i):
                airlines_in_mask[mask].append(i)

    # Read queries
    Q = int(input[ptr])
    ptr +=1
    for _ in range(Q):
        U = int(input[ptr])
        V = int(input[ptr+1])
        ptr +=2

        rail_cost = get_distance(U, V)
        if K ==0:
            print(rail_cost)
            continue

        a = [dist_airline[i][U] for i in range(K)]
        b = [dist_airline[i][V] for i in range(K)]

        min_cost = rail_cost

        for mask in range(1, 1 << K):
            sum_p = sum_P[mask]
            min_a_val = min(a[i] for i in airlines_in_mask[mask])
            min_b_val = min(b[i] for i in airlines_in_mask[mask])
            current_cost = sum_p + min_a_val + min_b_val
            if current_cost < min_cost:
                min_cost = current_cost

        print(min_cost)

if __name__ == '__main__':
    main()
0