結果

問題 No.114 遠い未来
ユーザー toyuzukotoyuzuko
提出日時 2021-01-31 16:58:12
言語 PyPy3
(7.3.15)
結果
TLE  
実行時間 -
コード長 4,291 bytes
コンパイル時間 281 ms
コンパイル使用メモリ 87,048 KB
実行使用メモリ 80,408 KB
最終ジャッジ日時 2023-09-11 11:04:05
合計ジャッジ時間 8,938 ms
ジャッジサーバーID
(参考情報)
judge12 / judge13
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 393 ms
80,408 KB
testcase_01 TLE -
testcase_02 -- -
testcase_03 -- -
testcase_04 -- -
testcase_05 -- -
testcase_06 -- -
testcase_07 -- -
testcase_08 -- -
testcase_09 -- -
testcase_10 -- -
testcase_11 -- -
testcase_12 -- -
testcase_13 -- -
testcase_14 -- -
testcase_15 -- -
testcase_16 -- -
testcase_17 -- -
testcase_18 -- -
testcase_19 -- -
testcase_20 -- -
testcase_21 -- -
testcase_22 -- -
testcase_23 -- -
testcase_24 -- -
testcase_25 -- -
testcase_26 -- -
testcase_27 -- -
権限があれば一括ダウンロードができます

ソースコード

diff #

from heapq import heappop, heappush

class Heap:
    """Minimal min-heap object backed by a list maintained with ``heapq``."""

    def __init__(self):
        # Backing list; heapq keeps the smallest item at index 0.
        self.heap = []

    def pop(self):
        """Remove and return the smallest item (raises IndexError if empty)."""
        return heappop(self.heap)

    def push(self, x):
        """Insert ``x`` while preserving the heap invariant."""
        heappush(self.heap, x)

    def top(self):
        """Return the smallest item without removing it."""
        smallest = self.heap[0]
        return smallest

    def size(self):
        """Return how many items are stored."""
        return len(self.heap)

    def is_empty(self):
        """Return True when the heap holds no items."""
        return not self.size()

class DisjointSetUnion():
    """Union-find over elements 0..n-1 with union by size and path compression.

    ``par_size[i]`` is the parent of ``i`` when ``i`` is not a root, and
    ``-(size of the set)`` when ``i`` is a root.
    """

    def __init__(self, n):
        self.n = n
        self.par_size = [-1] * n

    def merge(self, a, b):
        """Unite the sets containing ``a`` and ``b`` and return the new root."""
        x = self.leader(a)
        y = self.leader(b)
        if x == y:
            return x
        # Union by size: x becomes the root of the larger set.
        if self.par_size[x] > self.par_size[y]:
            x, y = y, x
        self.par_size[x] += self.par_size[y]
        self.par_size[y] = x
        return x

    def same(self, a, b):
        """Return True if ``a`` and ``b`` are in the same set."""
        return self.leader(a) == self.leader(b)

    def leader(self, a):
        """Return the root of ``a``'s set, re-pointing every node on the path at it."""
        par = self.par_size
        root = a
        while par[root] >= 0:
            root = par[root]
        # Full path compression: remember the next node *before* overwriting
        # the parent pointer, otherwise only the first node gets compressed.
        while a != root:
            nxt = par[a]
            par[a] = root
            a = nxt
        return root

    def size(self, a):
        """Return the number of elements in ``a``'s set."""
        return -self.par_size[self.leader(a)]

INF = 10**18


class Graph():
    """Undirected weighted graph on vertices 0..n-1 with a Steiner-tree solver."""

    def __init__(self, n):
        self.n = n
        # graph[u]: neighbours of u; each distinct edge is listed once per side.
        self.graph = [[] for _ in range(n)]
        # edge[min(u, v) * n + max(u, v)]: weight of the undirected edge {u, v}.
        self.edge = dict()

    def add_edge(self, u, v, c):
        """Add the undirected edge {u, v} with weight ``c``.

        Parallel edges collapse into one edge that keeps the smallest weight,
        since only the cheapest copy can ever appear in an optimal tree.
        """
        if v < u:
            u, v = v, u
        key = u * self.n + v
        if key in self.edge:
            if c < self.edge[key]:
                self.edge[key] = c
            return
        self.graph[u].append(v)
        self.graph[v].append(u)
        self.edge[key] = c

    def minimum_steiner_tree(self, terminal):
        """Return the weight of a minimum Steiner tree connecting ``terminal``.

        Dreyfus-Wagner DP: ``dp[S][v]`` is the cheapest tree containing the
        terminals selected by bitmask ``S`` plus vertex ``v``. Each mask gets a
        merge step (join two subtrees at ``v``) followed by a grow step
        (Dijkstra seeded with ``dp[S]``).

        terminal -- list of vertex indices to connect.
        Returns 0 for at most one terminal. Returns a value >= INF if the
        terminals are not all in one connected component.
        """
        n = self.n
        t = len(terminal)
        if t <= 1:
            return 0
        # Weighted adjacency built once, instead of a dict lookup per relaxation.
        adj = [[] for _ in range(n)]
        for key, c in self.edge.items():
            u, v = divmod(key, n)
            adj[u].append((v, c))
            adj[v].append((u, c))

        full = (1 << t) - 1
        dp = [[INF] * n for _ in range(full + 1)]
        for i, v in enumerate(terminal):
            dp[1 << i][v] = 0

        for bit in range(1, full + 1):
            cur = dp[bit]
            low = bit & -bit
            # Merge step over proper non-empty splits {sub, bit ^ sub}. Each
            # unordered split is visited once by requiring sub to hold the
            # lowest terminal of bit.
            sub = (bit - 1) & bit
            while sub:
                if sub & low:
                    left = dp[sub]
                    right = dp[bit ^ sub]
                    for v in range(n):
                        s = left[v] + right[v]
                        if s < cur[v]:
                            cur[v] = s
                sub = (sub - 1) & bit
            if bit == full:
                # terminal[0] is in every solution, so no grow step is needed.
                break
            # Grow step: multi-source Dijkstra. A sorted list already
            # satisfies the heap invariant, so it can seed heappop directly.
            heap = sorted((d, v) for v, d in enumerate(cur))
            while heap:
                d, v = heappop(heap)
                if cur[v] < d:
                    continue
                for w, c in adj[v]:
                    nd = d + c
                    if nd < cur[w]:
                        cur[w] = nd
                        heappush(heap, (nd, w))
        return dp[full][terminal[0]]

import sys
input = sys.stdin.buffer.readline


def _prefer_dreyfus_wagner(n, m, t):
    """Return True when the Dreyfus-Wagner DP looks cheaper than brute force.

    Rough operation counts: the DP merge step costs about 3**t * n, while
    enumerating Steiner-vertex subsets costs about 2**(n - t) Kruskal passes
    over m edges. Both methods are exact, so this only affects running time.
    """
    return 3 ** t * n <= 2 ** (n - t) * max(m, 1)


def _steiner_by_subset_mst(n, edges, terminals):
    """Return the minimum Steiner tree weight by enumerating Steiner vertices.

    For every subset of the non-terminal vertices, Kruskal's algorithm builds a
    minimum spanning tree of (terminals + subset). The cheapest result that
    connects all chosen vertices is returned.

    n         -- number of vertices, labelled 0..n-1
    edges     -- (a, b, c) triples with 0-indexed endpoints, sorted by c
    terminals -- 0-indexed vertices that must be connected
    Returns 0 for at most one terminal and 10**18 if no subset is connected.
    """
    if len(terminals) <= 1:
        return 0
    terminal_mask = 0
    for v in terminals:
        terminal_mask |= 1 << v
    others = [v for v in range(n) if not (terminal_mask >> v) & 1]

    best = 10**18
    for choice in range(1 << len(others)):
        active = terminal_mask
        for i, v in enumerate(others):
            if (choice >> i) & 1:
                active |= 1 << v
        needed = bin(active).count("1") - 1  # edges in a spanning tree
        dsu = DisjointSetUnion(n)
        total = 0
        for a, b, c in edges:
            if not ((active >> a) & 1 and (active >> b) & 1):
                continue
            if dsu.same(a, b):
                continue
            dsu.merge(a, b)
            total += c
            needed -= 1
            if needed == 0:
                break
            if c >= 0 and total >= best:
                # Edges are sorted, so every remaining edge costs >= c >= 0;
                # this subset can no longer beat the current best.
                break
        if needed == 0 and total < best:
            best = total
    return best


N, M, T = map(int, input().split())

edges = []
for _ in range(M):
    a, b, c = map(int, input().split())
    edges.append((a - 1, b - 1, c))
terminals = [int(input()) - 1 for _ in range(T)]

if _prefer_dreyfus_wagner(N, M, T):
    g = Graph(N)
    for a, b, c in edges:
        g.add_edge(a, b, c)
    print(g.minimum_steiner_tree(terminals))
else:
    edges.sort(key=lambda e: e[2])
    print(_steiner_by_subset_mst(N, edges, terminals))