結果
問題 | No.1833 Subway Planning |
ユーザー | neterukun |
提出日時 | 2022-02-04 23:59:05 |
言語 | PyPy3 (7.3.15) |
結果 | AC |
実行時間 | 2,744 ms / 4,000 ms |
コード長 | 10,638 bytes |
コンパイル時間 | 262 ms |
コンパイル使用メモリ | 82,048 KB |
実行使用メモリ | 376,668 KB |
最終ジャッジ日時 | 2024-06-11 13:09:08 |
合計ジャッジ時間 | 34,125 ms |
ジャッジサーバーID (参考情報) | judge3 / judge2 |
(要ログイン)
テストケース
入力 | 結果 | 実行時間 | 実行使用メモリ |
---|---|---|---|
testcase_00 | AC | 43 ms | 54,656 KB |
testcase_01 | AC | 42 ms | 54,912 KB |
testcase_02 | AC | 42 ms | 55,296 KB |
testcase_03 | AC | 43 ms | 54,912 KB |
testcase_04 | AC | 1,333 ms | 343,820 KB |
testcase_05 | AC | 1,368 ms | 346,224 KB |
testcase_06 | AC | 1,921 ms | 328,888 KB |
testcase_07 | AC | 2,690 ms | 346,916 KB |
testcase_08 | AC | 2,744 ms | 354,188 KB |
testcase_09 | AC | 894 ms | 225,900 KB |
testcase_10 | AC | 1,953 ms | 376,668 KB |
testcase_11 | AC | 879 ms | 230,072 KB |
testcase_12 | AC | 891 ms | 231,096 KB |
testcase_13 | AC | 1,870 ms | 360,548 KB |
testcase_14 | AC | 1,287 ms | 362,280 KB |
testcase_15 | AC | 1,872 ms | 358,612 KB |
testcase_16 | AC | 932 ms | 214,888 KB |
testcase_17 | AC | 2,469 ms | 347,984 KB |
testcase_18 | AC | 2,251 ms | 333,088 KB |
testcase_19 | AC | 1,084 ms | 213,188 KB |
testcase_20 | AC | 840 ms | 191,688 KB |
testcase_21 | AC | 2,412 ms | 332,968 KB |
testcase_22 | AC | 2,180 ms | 345,052 KB |
testcase_23 | AC | 42 ms | 54,720 KB |
testcase_24 | AC | 40 ms | 54,892 KB |
testcase_25 | AC | 41 ms | 54,636 KB |
ソースコード
"""Solution for yukicoder No.1833 "Subway Planning" (PyPy3).

Reads a weighted tree from stdin and prints one integer. The algorithm:

1. Mark both endpoints of every maximum-cost edge, then mark every vertex
   lying between marked vertices (found by rerooting DP). The marked vertices
   form a subtree.
2. If that subtree is not a path, print the maximum edge cost.
3. Otherwise start from that path and greedily extend it at its two ends,
   adding edges in decreasing cost order. Track the best value of
   ``max(max_cost - min_cost_on_path, largest_cost_off_path)``. Stop as soon
   as an extension would make the covered set branch.
"""
import sys
from bisect import bisect_left


class UnionFind:
    """Disjoint-set forest with union by size and path compression.

    ``parent[x]`` is negative for a root (its magnitude is the set size) and
    otherwise holds the parent index.
    """

    def __init__(self, n):
        self.parent = [-1] * n
        self.n = n
        self.cnt = n  # current number of disjoint sets

    def root(self, x):
        """Return the representative of ``x``'s set, compressing the path."""
        if self.parent[x] < 0:
            return x
        self.parent[x] = self.root(self.parent[x])
        return self.parent[x]

    def merge(self, x, y):
        """Union the sets of ``x`` and ``y``; return False if already joined."""
        x = self.root(x)
        y = self.root(y)
        if x == y:
            return False
        # Attach the smaller set (less negative size) under the larger one.
        if self.parent[x] > self.parent[y]:
            x, y = y, x
        self.parent[x] += self.parent[y]
        self.parent[y] = x
        self.cnt -= 1
        return True

    def same(self, x, y):
        """Return True if ``x`` and ``y`` are in the same set."""
        return self.root(x) == self.root(y)

    def size(self, x):
        """Return the number of elements in ``x``'s set."""
        return -self.parent[self.root(x)]

    def count(self):
        """Return the number of disjoint sets."""
        return self.cnt

    def groups(self):
        """Return the members of every set, ordered by representative index."""
        res = [[] for _ in range(self.n)]
        for i in range(self.n):
            res[self.root(i)].append(i)
        return [group for group in res if group]


def rerooting(n, edges, unit, merge, addnode, with_noderes=False):
    """Run rerooting DP over a forest with ``n`` vertices.

    ``edges`` are 0-indexed ``(u, v)`` pairs; adjacency lists are built in
    edge order. ``merge`` combines child values and is assumed associative
    with identity ``unit``. ``addnode(val, v)`` folds vertex ``v`` into an
    aggregated value.

    Returns ``sub`` where ``sub[v][i]`` is the aggregated value of the
    component on the side of ``v``'s i-th neighbour, as seen from ``v``.
    With ``with_noderes=True`` it returns ``(sub, noderes)``, where
    ``noderes[v]`` is the value of the whole tree rooted at ``v``.
    """
    tree = [[] for _ in range(n)]
    # idxs[u][k] = position of u inside tree[tree[u][k]] (the reverse edge).
    idxs = [[] for _ in range(n)]
    for u, v in edges:
        idxs[u].append(len(tree[v]))
        idxs[v].append(len(tree[u]))
        tree[u].append(v)
        tree[v].append(u)
    sub = [[unit] * len(tree[v]) for v in range(n)]
    noderes = [unit] * n

    # Iterative DFS order (parents before children) for every component.
    tp_order = []
    par = [-1] * n
    for root in range(n):
        if par[root] != -1:
            continue
        stack = [root]
        while stack:
            v = stack.pop()
            tp_order.append(v)
            for nxt_v in tree[v]:
                if nxt_v == par[v]:
                    continue
                par[nxt_v] = v
                stack.append(nxt_v)

    # Bottom-up: push each subtree's value into its parent's slot.
    for v in reversed(tp_order[1:]):
        res = unit
        par_idx = -1
        for idx, nxt_v in enumerate(tree[v]):
            if nxt_v == par[v]:
                par_idx = idx
                continue
            res = merge(res, sub[v][idx])
        if par_idx != -1:
            sub[par[v]][idxs[v][par_idx]] = addnode(res, v)

    # Top-down: for each neighbour, combine every other slot (prefix/suffix).
    for v in tp_order:
        acc_back = [unit] * len(tree[v])
        for i in reversed(range(1, len(acc_back))):
            acc_back[i - 1] = merge(sub[v][i], acc_back[i])
        acc_front = unit
        for idx, nxt_v in enumerate(tree[v]):
            res = addnode(merge(acc_front, acc_back[idx]), v)
            sub[nxt_v][idxs[v][idx]] = res
            acc_front = merge(acc_front, sub[v][idx])
        noderes[v] = addnode(acc_front, v)
    if with_noderes:
        return sub, noderes
    return sub


def topological_sorted(tree, root=None):
    """Return ``(order, par)`` from an iterative DFS over a weighted forest.

    ``tree[v]`` lists ``(neighbour, weight)`` pairs. If ``root`` is given,
    only its component is traversed. ``par`` holds -1 for roots and for
    unvisited vertices.
    """
    n = len(tree)
    par = [-1] * n
    tp_order = []
    for start in range(n):
        if par[start] != -1 or (root is not None and start != root):
            continue
        stack = [start]
        while stack:
            v = stack.pop()
            tp_order.append(v)
            for nxt_v, _ in tree[v]:
                if nxt_v == par[v]:
                    continue
                par[nxt_v] = v
                stack.append(nxt_v)
    return tp_order, par


class SortedSetBIT:
    """Sorted set over a fixed universe ``cands``, backed by a Fenwick tree.

    Only values present in ``cands`` may be added. Membership tests use
    ``val + 1`` as an exclusive bound, so values are assumed to be integers.
    """

    def __init__(self, cands):
        self.array = sorted(set(cands))
        self.comp = {val: i for i, val in enumerate(self.array)}
        self.size = len(self.array)
        self.cnt = 0  # number of stored elements
        self.bit = [0] * (self.size + 1)

    def __contains__(self, val):
        return self.count(val, val + 1) > 0

    def __len__(self):
        return self.cnt

    def _count(self, i):
        """Return the number of stored elements among the first ``i`` values."""
        res = 0
        while i > 0:
            res += self.bit[i]
            i -= i & -i
        return res

    def add(self, val):
        """Insert ``val``; return False if it is already present."""
        if val in self:
            return False
        i = self.comp[val] + 1
        while i <= self.size:
            self.bit[i] += 1
            i += i & -i
        self.cnt += 1
        return True

    def remove(self, val):
        """Remove one occurrence of ``val``; return False if absent."""
        if val not in self:
            return False
        i = self.comp[val] + 1
        while i <= self.size:
            self.bit[i] -= 1
            i += i & -i
        self.cnt -= 1
        return True

    def count(self, vl, vr):
        """Return the number of stored elements in ``[vl, vr)``."""
        l = bisect_left(self.array, vl)
        r = bisect_left(self.array, vr)
        return self._count(r) - self._count(l)

    def kth_smallest(self, k):
        """Return the 0-indexed ``k``-th smallest element (IndexError if out of range)."""
        if not (0 <= k < self.cnt):
            raise IndexError
        idx = 0
        k += 1
        d = 1 << self.size.bit_length()
        # Fenwick-tree binary descent.
        while d:
            if idx + d <= self.size and self.bit[idx + d] < k:
                k -= self.bit[idx + d]
                idx += d
            d >>= 1
        return self.array[idx]

    def kth_largest(self, k):
        """Return the 0-indexed ``k``-th largest element."""
        return self.kth_smallest(self.cnt - k - 1)

    def prev_val(self, upper):
        """Return the largest element ``< upper``, or None."""
        upper = bisect_left(self.array, upper)
        k = self._count(upper) - 1
        return None if k == -1 else self.kth_smallest(k)

    def next_val(self, lower):
        """Return the smallest element ``>= lower``, or None."""
        lower = bisect_left(self.array, lower)
        k = self._count(lower)
        return None if k == self.cnt else self.kth_smallest(k)

    def all_dump(self):
        """Return the stored elements in ascending order."""
        res = self.bit[:]
        # Undo the Fenwick aggregation: subtract each node from its parent.
        for i in reversed(range(1, self.size)):
            if i + (i & -i) > self.size:
                continue
            res[i + (i & -i)] -= res[i]
        return [self.array[i] for i, flag in enumerate(res[1:]) if flag]


class SortedMultiSetBIT(SortedSetBIT):
    """Multiset variant of :class:`SortedSetBIT` (duplicates allowed)."""

    def __init__(self, cands):
        super().__init__(cands)

    def add(self, val):
        """Insert one occurrence of ``val`` (must belong to ``cands``)."""
        i = self.comp[val] + 1
        while i <= self.size:
            self.bit[i] += 1
            i += i & -i
        self.cnt += 1
        return True

    def all_remove(self, val):
        """Remove every occurrence of ``val``; return False if absent."""
        if val not in self:
            return False
        i = self.comp[val] + 1
        cnt = self._count(i) - self._count(i - 1)
        while i <= self.size:
            self.bit[i] -= cnt
            i += i & -i
        self.cnt -= cnt
        return True

    def all_dump(self):
        """Return ``(value, multiplicity)`` pairs in ascending value order."""
        res = self.bit[:]
        for i in reversed(range(1, self.size)):
            if i + (i & -i) > self.size:
                continue
            res[i + (i & -i)] -= res[i]
        return [(self.array[i], cnt) for i, cnt in enumerate(res[1:]) if cnt]


def _side_tree(group, start, edges):
    """Build the subtree induced by ``group`` using compact vertex ids.

    ``edges`` are 1-indexed ``(u, v, cost)`` triples. Returns
    ``(adjacency, root)``: ``adjacency[i]`` lists ``(j, cost)`` for edges
    with both ends in ``group``, and ``root`` is the compact id of ``start``.
    """
    mapping = {val: i for i, val in enumerate(sorted(group))}
    adj = [[] for _ in range(len(group))]
    for u, v, cost in edges:
        u -= 1
        v -= 1
        if u in mapping and v in mapping:
            adj[mapping[u]].append((mapping[v], cost))
            adj[mapping[v]].append((mapping[u], cost))
    return adj, mapping[start]


def _index_edges(adj):
    """Return ``(cost_of, by_cost)`` for the weighted adjacency list ``adj``.

    ``cost_of[a, b]`` is the cost of edge a-b (stored in both directions).
    ``by_cost[c]`` lists the endpoints of every edge of cost ``c``, with
    repeats.
    """
    cost_of = {}
    by_cost = {}
    for v, nbrs in enumerate(adj):
        for nxt_v, cost in nbrs:
            cost_of[v, nxt_v] = cost
            cost_of[nxt_v, v] = cost
            bucket = by_cost.setdefault(cost, [])
            bucket.append(v)
            bucket.append(nxt_v)
    return cost_of, by_cost


def _extend_path(vertices, par, used, cost_of, end, st_set, min_cost):
    """Attach ``vertices`` to the path at the current ``end`` of one side.

    Each unused vertex climbs ``par`` until it reaches a used vertex. Every
    edge climbed joins the path: its cost leaves ``st_set`` and may lower
    ``min_cost``. If a climb stops anywhere other than ``end``, the covered
    set would branch and the extension fails.

    Returns ``(ok, end, min_cost)``. ``used`` and ``st_set`` are mutated.
    """
    for v in set(vertices):
        if used[v]:
            continue
        cur = v
        while not used[cur]:
            cost = cost_of[cur, par[cur]]
            min_cost = min(min_cost, cost)
            st_set.remove(cost)
            used[cur] = True
            cur = par[cur]
        if cur != end:
            return False, end, min_cost
        end = v
    return True, end, min_cost


def solve(n, edges):
    """Return the answer for a tree of ``n`` vertices.

    ``edges`` holds ``n - 1`` 1-indexed ``(u, v, cost)`` triples.
    """
    if not edges:
        # NOTE(review): the original crashed here (no path ends). 0 is the
        # natural value for an edgeless tree; confirm against the problem
        # constraints.
        return 0
    INF = 10 ** 18
    es = []
    tree = [[] for _ in range(n)]
    max_cost = 0
    for u, v, cost in edges:
        u -= 1
        v -= 1
        max_cost = max(max_cost, cost)
        es.append((u, v))
        tree[u].append((v, cost))
        tree[v].append((u, cost))

    # Mark both endpoints of every maximum-cost edge.
    vals = [0] * n
    for u, v, cost in edges:
        if cost == max_cost:
            vals[u - 1] = 1
            vals[v - 1] = 1

    # sub[v][i] = number of marked vertices on neighbour i's side of v.
    # A vertex with marked vertices on two or more sides lies between them.
    sub = rerooting(n, es, 0, lambda a, b: a + b, lambda val, v: val + vals[v])
    for v, res in enumerate(sub):
        if len(res) <= 1:
            continue
        if sum(1 for c in res if c) >= 2:
            vals[v] = 1

    # The marked vertices form a subtree. If any vertex has more than two
    # marked neighbours, it is not a path and no subway path can cover every
    # maximum edge. Record min_cost = cheapest edge on the path and the two
    # path ends.
    min_cost = INF
    ends = []
    for v in range(n):
        if vals[v] == 0:
            continue
        cnt = 0
        for nxt_v, cost in tree[v]:
            if vals[nxt_v] == 1:
                min_cost = min(min_cost, cost)
                cnt += 1
        if cnt > 2:
            return max_cost
        if cnt == 1:
            ends.append(v)

    # Cut the path edges, so each remaining component hangs off one path vertex.
    uf = UnionFind(n)
    for u, v in es:
        if not (vals[u] == 1 and vals[v] == 1):
            uf.merge(u, v)

    # Multiset of costs of edges off the path. The extra 0 keeps the maximum
    # defined even once every edge has been absorbed.
    cands = [
        cost
        for u, v, cost in edges
        if not (vals[u - 1] == 1 and vals[v - 1] == 1)
    ]
    st_set = SortedMultiSetBIT(cands + [0])
    for val in cands:
        st_set.add(val)
    st_set.add(0)

    # Only the components containing the two path ends can extend the path.
    sides = []
    for gp in uf.groups():
        members = set(gp)
        if ends[0] in members:
            sides.append(_side_tree(gp, ends[0], edges))
        elif ends[1] in members:
            sides.append(_side_tree(gp, ends[1], edges))
    (tree1, root1), (tree2, root2) = sides

    cost_of1, by_cost1 = _index_edges(tree1)
    cost_of2, by_cost2 = _index_edges(tree2)
    _, par1 = topological_sorted(tree1, root1)
    _, par2 = topological_sorted(tree2, root2)

    ans = max(max_cost - min_cost, st_set.kth_largest(0))
    used1 = [False] * len(tree1)
    used1[root1] = True
    end1 = root1
    used2 = [False] * len(tree2)
    used2[root2] = True
    end2 = root2
    # Absorb edges from the most to the least expensive; stop at the first branch.
    for val in sorted(set(by_cost1) | set(by_cost2), reverse=True):
        if val in by_cost1:
            ok, end1, min_cost = _extend_path(
                by_cost1[val], par1, used1, cost_of1, end1, st_set, min_cost
            )
            if not ok:
                return ans
        if val in by_cost2:
            ok, end2, min_cost = _extend_path(
                by_cost2[val], par2, used2, cost_of2, end2, st_set, min_cost
            )
            if not ok:
                return ans
        ans = min(ans, max(max_cost - min_cost, st_set.kth_largest(0)))
    return ans


def main():
    """Read the tree from stdin and print the answer."""
    readline = sys.stdin.buffer.readline
    n = int(readline())
    edges = [list(map(int, readline().split())) for _ in range(n - 1)]
    print(solve(n, edges))


if __name__ == "__main__":
    main()