Result
Problem | No.19 ステージの選択 (Stage Selection) |
User | rpy3cpp |
Submitted | 2015-08-08 02:27:30 |
Language | Python3 (3.12.2 + numpy 1.26.4 + scipy 1.12.0) |
Result | AC |
Execution time | 20 ms / 5,000 ms |
Code length | 2,071 bytes |
Compile time | 80 ms |
Compile memory | 10,928 KB |
Execution memory | 8,872 KB |
Last judged | 2023-08-24 23:29:14 |
Total judge time | 1,627 ms |
Judge server ID (for reference) | judge15 / judge14 |
Test cases
Test case | Result | Execution time | Execution memory |
---|---|---|---|
testcase_00 | AC | 19 ms | 8,668 KB |
testcase_01 | AC | 19 ms | 8,660 KB |
testcase_02 | AC | 19 ms | 8,860 KB |
testcase_03 | AC | 19 ms | 8,728 KB |
testcase_04 | AC | 19 ms | 8,800 KB |
testcase_05 | AC | 19 ms | 8,840 KB |
testcase_06 | AC | 19 ms | 8,872 KB |
testcase_07 | AC | 19 ms | 8,784 KB |
testcase_08 | AC | 19 ms | 8,760 KB |
testcase_09 | AC | 20 ms | 8,868 KB |
testcase_10 | AC | 19 ms | 8,668 KB |
testcase_11 | AC | 19 ms | 8,664 KB |
testcase_12 | AC | 19 ms | 8,688 KB |
testcase_13 | AC | 19 ms | 8,792 KB |
testcase_14 | AC | 19 ms | 8,820 KB |
testcase_15 | AC | 19 ms | 8,816 KB |
testcase_16 | AC | 19 ms | 8,760 KB |
testcase_17 | AC | 19 ms | 8,688 KB |
testcase_18 | AC | 20 ms | 8,868 KB |
testcase_19 | AC | 19 ms | 8,732 KB |
testcase_20 | AC | 19 ms | 8,772 KB |
testcase_21 | AC | 19 ms | 8,728 KB |
testcase_22 | AC | 19 ms | 8,756 KB |
testcase_23 | AC | 19 ms | 8,824 KB |
Source code

```python
import collections


class DisjointSet(object):
    """Union-find with union by rank and path compression."""

    def __init__(self, n):
        self.parent = list(range(n))
        self.rank = [0] * n
        self.num = n  # number of disjoint sets

    def union(self, x, y):
        self._link(self.find_set(x), self.find_set(y))

    def _link(self, x, y):
        # x and y are roots; attach the lower-rank tree under the higher one.
        if x == y:
            return
        self.num -= 1
        if self.rank[x] > self.rank[y]:
            self.parent[y] = x
        else:
            self.parent[x] = y
            if self.rank[x] == self.rank[y]:
                self.rank[y] += 1

    def find_set(self, x):
        # Recursive find with path compression.
        xp = self.parent[x]
        if xp != x:
            self.parent[x] = self.find_set(xp)
        return self.parent[x]


def read_data():
    # dst2src[i]: 0-based prerequisite stage of stage i.
    # src2dst[j]: set of stages whose prerequisite is stage j.
    N = int(input())
    Ls = []
    dst2src = []
    src2dst = [set() for n in range(N)]
    for n in range(N):
        l, s = map(int, input().split())
        Ls.append(l)
        dst2src.append(s - 1)
        src2dst[s - 1].add(n)
    return N, Ls, dst2src, src2dst


def solve(N, Ls, dst2src, src2dst):
    # Score each connected component independently, then halve the total.
    groups = decompose(N, dst2src)
    total = 0
    for group in groups:
        total += evaluate_difficulty(group, Ls, dst2src, src2dst)
    return total/2


def decompose(N, dst2src):
    # Group stages into connected components of the prerequisite graph.
    djs = DisjointSet(N)
    for d, s in enumerate(dst2src):
        djs.union(d, s)
    groups = collections.defaultdict(set)
    for i in range(N):
        groups[djs.find_set(i)].add(i)
    return list(groups.values())


def evaluate_difficulty(group, Ls, dst2src, src2dst):
    # Sum of the component's difficulties plus the smallest difficulty
    # among the stages left after trimming (the component's cycle).
    score = sum(Ls[i] for i in group)
    trim_leaf(group, dst2src, src2dst)
    score += min(Ls[i] for i in group)
    return score


def trim_leaf(group, dst2src, src2dst):
    # Repeatedly remove stages that no other stage depends on.
    # Mutates group and src2dst; what remains in group is the cycle.
    leaf = [i for i in group if not src2dst[i]]
    while leaf:
        dst = leaf[-1]
        del leaf[-1]
        group.remove(dst)
        src = dst2src[dst]
        src2dst[src].remove(dst)
        if not src2dst[src]:
            leaf.append(src)


if __name__ == '__main__':
    N, Ls, dst2src, src2dst = read_data()
    print('{:.1f}'.format(solve(N, Ls, dst2src, src2dst)))
```
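Reading the submission, each stage i has exactly one prerequisite S_i, so the prerequisite links form a functional graph in which every connected component contains exactly one cycle (possibly a self-loop). The program sums the difficulties of each component, adds the smallest difficulty on that component's cycle, and halves the grand total, i.e. it prints $\frac{1}{2}\left(\sum_i L_i + \sum_{C} \min_{i \in C} L_i\right)$ over all cycles $C$. The sketch below computes the same quantity with a different technique, not the author's: instead of union-find plus leaf trimming, it walks prerequisite links from each unvisited stage and detects a new cycle when the walk returns to a stage marked on that same walk. The function name `min_total_difficulty` and the 0-based `prereq` list are hypothetical, chosen for illustration, and the "clearing the prerequisite first halves the difficulty" reading is an assumption inferred from the final division by 2.

```python
def min_total_difficulty(levels, prereq):
    """Return (sum(levels) + sum of the minimum level on each cycle) / 2.

    levels[i] is the difficulty of stage i; prereq[i] is the 0-based index
    of stage i's prerequisite (assumed: clearing it first halves stage i).
    """
    n = len(levels)
    state = [0] * n  # 0 = unvisited, 1 = on the current walk, 2 = finished
    extra = 0        # sum of the minimum level on each cycle
    for start in range(n):
        if state[start] != 0:
            continue
        path = []
        v = start
        # Follow prerequisite links until reaching a stage already seen.
        while state[v] == 0:
            state[v] = 1
            path.append(v)
            v = prereq[v]
        if state[v] == 1:
            # The walk looped back onto itself: path[path.index(v):] is a new cycle.
            cycle = path[path.index(v):]
            extra += min(levels[i] for i in cycle)
        # Either way, every stage on this walk is now fully processed.
        for u in path:
            state[u] = 2
    return (sum(levels) + extra) / 2


print(min_total_difficulty([1, 2, 3], [1, 2, 0]))  # 3.5
```

To use it on judge input, convert each S to 0-based as `read_data` does (`s - 1`) and print with `'{:.1f}'.format(...)`. Because the walk is iterative, it does not depend on Python's recursion depth the way the recursive `find_set` does.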