結果

問題 No.1301 Strange Graph Shortest Path
ユーザー lam6er
提出日時 2025-03-26 15:54:22
言語 PyPy3
(7.3.15)
結果
WA  
実行時間 -
コード長 3,851 bytes
コンパイル時間 186 ms
コンパイル使用メモリ 82,232 KB
実行使用メモリ 248,080 KB
最終ジャッジ日時 2025-03-26 15:55:56
合計ジャッジ時間 37,378 ms
ジャッジサーバーID
(参考情報)
judge4 / judge2
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 2
other AC * 28 WA * 5
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
import heapq
from collections import defaultdict

def main():
    input = sys.stdin.read().split()
    idx = 0
    N, M = int(input[idx]), int(input[idx+1])
    idx +=2

    edges = [[] for _ in range(N+1)]
    edge_info = {}
    for _ in range(M):
        u = int(input[idx])
        v = int(input[idx+1])
        c = int(input[idx+2])
        d = int(input[idx+3])
        idx +=4
        edges[u].append((v, c, d))
        edges[v].append((u, c, d))
        if (u, v) not in edge_info:
            edge_info[(u, v)] = (c, d)
        if (v, u) not in edge_info:
            edge_info[(v, u)] = (c, d)

    # Function to run Dijkstra and return the shortest path cost and the edges used
    def dijkstra(start, end, graph, track_edges=False):
        dist = [float('inf')] * (N+1)
        dist[start] = 0
        parent = [[] for _ in range(N+1)]
        heap = [(0, start)]
        heapq.heapify(heap)
        while heap:
            current_dist, u = heapq.heappop(heap)
            if u == end:
                break
            if current_dist > dist[u]:
                continue
            for (v, c, d) in graph[u]:
                if dist[v] > current_dist + c:
                    dist[v] = current_dist + c
                    parent[v] = [(u, (u, v, c, d))]
                    heapq.heappush(heap, (dist[v], v))
                elif dist[v] == current_dist + c:
                    parent[v].append((u, (u, v, c, d)))

        if dist[end] == float('inf'):
            return (float('inf'), set()) if track_edges else float('inf')

        if not track_edges:
            return dist[end]

        # Backtrack to find edges used in the shortest path
        visited = set()
        edge_set = set()
        q = [end]
        while q:
            v = q.pop()
            if v in visited:
                continue
            visited.add(v)
            for (u, e) in parent[v]:
                edge = (min(u, e[1]), max(u, e[1]))
                edge_set.add(edge)
                q.append(u)
        return (dist[end], edge_set)

    # Get the shortest path S1 and the edges used
    s1, edges_s1 = dijkstra(1, N, edges, track_edges=True)
    if s1 == float('inf'):
        print(-1)
        return

    # Function to create a new graph by removing edges in edges_s1
    def remove_edges(original_edges, edges_to_remove):
        new_edges = [[] for _ in range(N+1)]
        for u in range(1, N+1):
            for (v, c, d) in original_edges[u]:
                edge = (min(u, v), max(u, v))
                if edge not in edges_to_remove:
                    new_edges[u].append((v, c, d))
        return new_edges

    # Compute candidate1: remove edges in s1 and find path from N to 1
    graph_removed = remove_edges(edges, edges_s1)
    candidate1 = dijkstra(N, 1, graph_removed, track_edges=False)

    # Function to create a new graph where edges in edges_s1 have their cost changed to d
    def modify_edges(original_edges, edges_to_modify):
        new_edges = [[] for _ in range(N+1)]
        for u in range(1, N+1):
            for (v, c, d) in original_edges[u]:
                edge = (min(u, v), max(u, v))
                if edge in edges_to_modify:
                    new_edges[u].append((v, d, d))
                else:
                    new_edges[u].append((v, c, d))
        return new_edges

    # Compute candidate2: modify edges in s1 to d and find path from N to 1
    graph_modified = modify_edges(edges, edges_s1)
    candidate2 = dijkstra(N, 1, graph_modified, track_edges=False)

    # Calculate possible answers
    min_total = float('inf')
    if candidate1 != float('inf'):
        min_total = s1 + candidate1
    if candidate2 != float('inf'):
        min_total = min(min_total, s1 + candidate2)

    print(min_total if min_total != float('inf') else -1)

if __name__ == '__main__':
    main()
0