結果

問題 No.114 遠い未来
ユーザー Mao-betaMao-beta
提出日時 2024-03-08 12:32:24
言語 PyPy3
(7.3.15)
結果
TLE  
実行時間 -
コード長 4,923 bytes
コンパイル時間 271 ms
コンパイル使用メモリ 82,304 KB
実行使用メモリ 84,096 KB
最終ジャッジ日時 2024-09-29 18:47:18
合計ジャッジ時間 7,211 ms
ジャッジサーバーID
(参考情報)
judge5 / judge2
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 284 ms
77,184 KB
testcase_01 TLE -
testcase_02 -- -
testcase_03 -- -
testcase_04 -- -
testcase_05 -- -
testcase_06 -- -
testcase_07 -- -
testcase_08 -- -
testcase_09 -- -
testcase_10 -- -
testcase_11 -- -
testcase_12 -- -
testcase_13 -- -
testcase_14 -- -
testcase_15 -- -
testcase_16 -- -
testcase_17 -- -
testcase_18 -- -
testcase_19 -- -
testcase_20 -- -
testcase_21 -- -
testcase_22 -- -
testcase_23 -- -
testcase_24 -- -
testcase_25 -- -
testcase_26 -- -
testcase_27 -- -
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
import math
import bisect
from heapq import heapify, heappop, heappush
from collections import deque, defaultdict, Counter
from functools import lru_cache
from itertools import accumulate, combinations, permutations, product

sys.setrecursionlimit(1000000)
MOD = 10 ** 9 + 7
MOD99 = 998244353

input = lambda: sys.stdin.readline().strip()
NI = lambda: int(input())
NMI = lambda: map(int, input().split())
NLI = lambda: list(NMI())
SI = lambda: input()
SMI = lambda: input().split()
SLI = lambda: list(SMI())
EI = lambda m: [NLI() for _ in range(m)]


def main():
    N, M, T = NMI()
    ABC = EI(M)
    ABC = [[x-1, y-1, z] for x, y, z in ABC]
    V = [NI() for _ in range(T)]
    V = [x-1 for x in V]
    INF = 10 ** 15

    if T <= 14:
        # 最小シュタイナー木
        # ワーシャルフロイド
        D = [[INF]*N for _ in range(N)]
        for i in range(N):
            D[i][i] = 0
        for a, b, c in ABC:
            D[a][b] = c
            D[b][a] = c
        for k in range(N):
            for i in range(N):
                for j in range(N):
                    D[i][j] = min(D[i][j], D[i][k] + D[k][j])
        # dp[i][S]: iを端点に持ち、Vの部分集合S(T-bit)を含むシュタイナー木の重み
        dp = [[INF] * (1<<T) for _ in range(N)]
        # 各vについて、端点がiのときの初期値
        for vi in range(T):
            for i in range(N):
                dp[i][1<<vi] = D[i][V[vi]]
            dp[V[vi]][1<<vi] = 0
        for i in range(N):
            dp[i][0] = 0

        def gen_subset(S):
            s = (S-1) & S
            while s > 0:
                yield s
                s = (s-1) & S

        # O(3^T)の部分集合DP
        # トータルでO(N*3^T + N^2*2^T)
        for S in range(1, 1<<T):
            for i in range(N):
                for E in gen_subset(S):
                    dp[i][S] = min(dp[i][S], dp[i][S-E] + dp[i][E])
            for i in range(N):
                for j in range(N):
                    dp[i][S] = min(dp[i][S], dp[j][S] + D[i][j])

        ans = INF
        for i in range(N):
            for S in range(1<<T):
                ans = min(ans, dp[i][S] + dp[i][(1<<T)-1-S])
        print(ans)

    else:
        # N-T <= 20
        # 使わない頂点の集合を全探索してMST

        class UnionFind:
            def __init__(self, n):
                # 親要素のノード番号を格納 xが根のとき-(サイズ)を格納
                self.par = [-1 for i in range(n)]
                self.n = n
                self.roots = set(range(n))
                self.group_num = n

            def find(self, x):
                # 根ならその番号を返す
                if self.par[x] < 0:
                    return x
                else:
                    # 親の親は親
                    self.par[x] = self.find(self.par[x])
                    return self.par[x]

            def is_same(self, x, y):
                # 根が同じならTrue
                return self.find(x) == self.find(y)

            def unite(self, x, y):
                x = self.find(x)
                y = self.find(y)
                if x == y: return

                # 木のサイズを比較し、小さいほうから大きいほうへつなぐ
                if self.par[x] > self.par[y]:
                    x, y = y, x

                self.group_num -= 1
                self.roots.discard(y)
                assert self.group_num == len(self.roots)

                self.par[x] += self.par[y]
                self.par[y] = x

            def size(self, x):
                return -self.par[self.find(x)]

            def get_roots(self):
                return self.roots

            def group_count(self):
                return len(self.roots)


        def MST(N, edges, target):
            """
            要UnionFind
            N頂点のうち、target[i]==1の点のみの最小全域木の長さ
            edges = [[u, v, cost], ....] (0-index) (sort済み)
            """
            uf = UnionFind(N)
            # edges.sort(key=lambda x: x[-1])
            res = 0
            for a, b, c in edges:
                if target[a] == 0 or target[b] == 0:
                    continue
                if uf.is_same(a, b):
                    continue
                else:
                    res += c
                    uf.unite(a, b)
            return res

        ABC.sort(key=lambda x: x[-1])
        Vbar = [i for i in range(N) if i not in V]
        Vbn = len(Vbar)
        target = [1] * N

        ans = INF
        for case in range(1<<Vbn):
            for i in range(Vbn):
                if (case >> i) & 1:
                    target[Vbar[i]] = 1
                else:
                    target[Vbar[i]] = 0

            res = MST(N, ABC, target)
            ans = min(ans, res)

        print(ans)


if __name__ == "__main__":
    main()
0