結果

問題 No.1833 Subway Planning
ユーザー neterukun
提出日時 2022-02-04 23:57:30
言語 PyPy3
(7.3.15)
結果
RE  
実行時間 -
コード長 10,619 bytes
コンパイル時間 397 ms
コンパイル使用メモリ 82,304 KB
実行使用メモリ 375,540 KB
最終ジャッジ日時 2024-06-11 13:08:11
合計ジャッジ時間 33,662 ms
ジャッジサーバーID
(参考情報)
judge2 / judge5
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 RE -
testcase_01 AC 41 ms
54,784 KB
testcase_02 AC 42 ms
55,040 KB
testcase_03 AC 42 ms
54,528 KB
testcase_04 RE -
testcase_05 RE -
testcase_06 RE -
testcase_07 RE -
testcase_08 RE -
testcase_09 RE -
testcase_10 AC 1,893 ms
375,540 KB
testcase_11 AC 926 ms
230,088 KB
testcase_12 AC 862 ms
231,364 KB
testcase_13 AC 1,849 ms
360,300 KB
testcase_14 AC 1,245 ms
361,240 KB
testcase_15 AC 1,824 ms
358,364 KB
testcase_16 AC 936 ms
215,156 KB
testcase_17 AC 2,336 ms
347,336 KB
testcase_18 AC 2,241 ms
332,672 KB
testcase_19 AC 1,003 ms
213,076 KB
testcase_20 AC 837 ms
189,572 KB
testcase_21 AC 2,353 ms
335,372 KB
testcase_22 AC 2,098 ms
344,212 KB
testcase_23 RE -
testcase_24 AC 41 ms
54,912 KB
testcase_25 AC 41 ms
55,040 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
input = sys.stdin.buffer.readline


class UnionFind:
    """Disjoint-set forest using union by size and path compression.

    ``parent[v]`` stores the parent of ``v``; for a representative it stores
    the negated number of elements in that set instead.
    """

    def __init__(self, n):
        self.parent = [-1] * n
        self.n = n
        self.cnt = n  # number of disjoint sets currently alive

    def root(self, x):
        """Return the representative of ``x`` and flatten the visited path."""
        parent = self.parent
        # First pass: locate the representative.
        top = x
        while parent[top] >= 0:
            top = parent[top]
        # Second pass: hang every vertex on the path directly below it.
        while x != top:
            above = parent[x]
            parent[x] = top
            x = above
        return top

    def merge(self, x, y):
        """Unite the sets of ``x`` and ``y``; return False if already united."""
        keep, drop = self.root(x), self.root(y)
        if keep == drop:
            return False
        # Sizes are stored negated, so the "greater" entry is the smaller set.
        if self.parent[keep] > self.parent[drop]:
            keep, drop = drop, keep
        self.parent[keep] += self.parent[drop]
        self.parent[drop] = keep
        self.cnt -= 1
        return True

    def same(self, x, y):
        """Return True when ``x`` and ``y`` belong to the same set."""
        return self.root(x) == self.root(y)

    def size(self, x):
        """Return the number of elements in the set containing ``x``."""
        return -self.parent[self.root(x)]

    def count(self):
        """Return the current number of disjoint sets."""
        return self.cnt

    def groups(self):
        """Return every set as a list of members, ordered by representative."""
        buckets = [[] for _ in range(self.n)]
        for member in range(self.n):
            buckets[self.root(member)].append(member)
        return [bucket for bucket in buckets if bucket]


def rerooting(n, edges, unit, merge, addnode):
    """Run a rerooting DP over a forest and return the per-edge results.

    Args:
        n: Number of vertices; vertices are ``0 .. n-1``.
        edges: Iterable of ``(u, v)`` pairs describing an undirected forest.
        unit: Identity value for ``merge``.
        merge: ``merge(a, b)`` combines two accumulated values.
        addnode: ``addnode(acc, v)`` folds vertex ``v`` into an accumulated
            value of the rest of its side.

    Returns:
        A list ``sub`` where ``sub[v][i]`` is the DP value of the side that
        contains the ``i``-th neighbour of ``v`` (in edge insertion order)
        when the edge between them is cut, i.e. ``addnode`` applied at that
        neighbour to the merge of all its other directions.
    """
    tree = [[] for _ in range(n)]
    # idxs[v][i] is the position of v inside the adjacency list of tree[v][i],
    # so a value computed at v can be written straight into the neighbour.
    idxs = [[] for _ in range(n)]
    for u, v in edges:
        idxs[u].append(len(tree[v]))
        idxs[v].append(len(tree[u]))
        tree[u].append(v)
        tree[v].append(u)
    sub = [[unit] * len(tree[v]) for v in range(n)]

    # Iterative DFS giving an order in which parents precede their children.
    tp_order = []
    par = [-1] * n
    for root in range(n):
        if par[root] != -1:
            continue
        stack = [root]
        while stack:
            v = stack.pop()
            tp_order.append(v)
            for nxt_v in tree[v]:
                if nxt_v == par[v]:
                    continue
                par[nxt_v] = v
                stack.append(nxt_v)

    # Bottom-up pass: push each subtree's value into its parent's slot.
    for v in reversed(tp_order[1:]):
        res = unit
        par_idx = -1
        for idx, nxt_v in enumerate(tree[v]):
            if nxt_v == par[v]:
                par_idx = idx
                continue
            res = merge(res, sub[v][idx])
        # Roots of later components have no parent slot to fill.
        if par_idx != -1:
            sub[par[v]][idxs[v][par_idx]] = addnode(res, v)

    # Top-down pass: for each neighbour, combine every *other* direction of v
    # using a running prefix and a precomputed suffix.
    for v in tp_order:
        acc_back = [unit] * len(tree[v])
        for i in reversed(range(1, len(acc_back))):
            acc_back[i - 1] = merge(sub[v][i], acc_back[i])
        acc_front = unit
        for idx, nxt_v in enumerate(tree[v]):
            sub[nxt_v][idxs[v][idx]] = addnode(merge(acc_front, acc_back[idx]), v)
            acc_front = merge(acc_front, sub[v][idx])
    return sub


def topological_sorted(tree, root=None):
    """Return a parent-before-child visiting order and the parent array.

    ``tree[v]`` is a list of ``(neighbour, cost)`` pairs.  When ``root`` is
    given only the component reachable from it is traversed; otherwise every
    component is traversed, each starting from its lowest-numbered vertex.
    Vertices that are roots (or never reached) keep parent ``-1``.
    """
    size = len(tree)
    parent = [-1] * size
    visit_order = []
    for start in range(size):
        if parent[start] == -1 and (root is None or start == root):
            pending = [start]
            while pending:
                cur = pending.pop()
                visit_order.append(cur)
                for neighbour, _cost in tree[cur]:
                    if neighbour != parent[cur]:
                        parent[neighbour] = cur
                        pending.append(neighbour)
    return visit_order, parent


from bisect import bisect_left


class SortedSetBIT:
    """Ordered set over a fixed universe of candidates, backed by a Fenwick tree.

    Every value that will ever be inserted must appear in ``cands``; the
    values are coordinate-compressed and the tree counts which are present.
    """

    def __init__(self, cands):
        universe = sorted(set(cands))
        self.array = universe                      # rank -> value
        self.comp = {value: rank for rank, value in enumerate(universe)}
        self.size = len(universe)
        self.cnt = 0                               # number of stored elements
        self.bit = [0] * (self.size + 1)           # 1-indexed Fenwick tree

    def __contains__(self, val):
        return self.count(val, val + 1) > 0

    def __len__(self):
        return self.cnt

    def _count(self, i):
        """Return how many stored elements have rank below ``i``."""
        total = 0
        while i > 0:
            total += self.bit[i]
            i &= i - 1  # strip the lowest set bit
        return total

    def _shift(self, val, delta):
        """Apply ``delta`` to the multiplicity of ``val`` and to the size."""
        pos = self.comp[val] + 1
        while pos <= self.size:
            self.bit[pos] += delta
            pos += pos & -pos
        self.cnt += delta

    def add(self, val):
        """Insert ``val``; return False if it was already present."""
        if val in self:
            return False
        self._shift(val, 1)
        return True

    def remove(self, val):
        """Delete one occurrence of ``val``; return False if it is absent."""
        if val not in self:
            return False
        self._shift(val, -1)
        return True

    def count(self, vl, vr):
        """Return the number of stored elements in the half-open ``[vl, vr)``."""
        lo = bisect_left(self.array, vl)
        hi = bisect_left(self.array, vr)
        return self._count(hi) - self._count(lo)

    def kth_smallest(self, k):
        """Return the ``k``-th smallest element (0-indexed).

        Raises:
            IndexError: if ``k`` is outside ``[0, len(self))``.
        """
        if k < 0 or k >= self.cnt:
            raise IndexError
        pos = 0
        need = k + 1
        step = 1 << self.size.bit_length()
        # Binary descent over the Fenwick tree.
        while step:
            probe = pos + step
            if probe <= self.size and self.bit[probe] < need:
                need -= self.bit[probe]
                pos = probe
            step >>= 1
        return self.array[pos]

    def kth_largest(self, k):
        """Return the ``k``-th largest element (0-indexed)."""
        return self.kth_smallest(self.cnt - k - 1)

    def prev_val(self, upper):
        """Return the largest element strictly below ``upper``, else None."""
        rank = bisect_left(self.array, upper)
        k = self._count(rank) - 1
        if k == -1:
            return None
        return self.kth_smallest(k)

    def next_val(self, lower):
        """Return the smallest element not below ``lower``, else None."""
        rank = bisect_left(self.array, lower)
        k = self._count(rank)
        if k == self.cnt:
            return None
        return self.kth_smallest(k)

    def all_dump(self):
        """Return all stored values in ascending order."""
        raw = list(self.bit)
        # Undo the Fenwick aggregation from the top down to recover counts.
        for i in range(self.size - 1, 0, -1):
            j = i + (i & -i)
            if j <= self.size:
                raw[j] -= raw[i]
        return [self.array[rank] for rank, present in enumerate(raw[1:]) if present]

class SortedMultiSetBIT(SortedSetBIT):
    """Multiset variant of ``SortedSetBIT``: equal values may be stored repeatedly.

    Construction is inherited unchanged; the inherited ``remove`` deletes a
    single occurrence.
    """

    def add(self, val):
        """Insert one more occurrence of ``val``; always returns True."""
        pos = self.comp[val] + 1
        while pos <= self.size:
            self.bit[pos] += 1
            pos += pos & -pos
        self.cnt += 1
        return True

    def all_remove(self, val):
        """Delete every occurrence of ``val``; return False if none exists."""
        if val not in self:
            return False
        pos = self.comp[val] + 1
        # Multiplicity of val = difference of two adjacent prefix counts.
        copies = self._count(pos) - self._count(pos - 1)
        while pos <= self.size:
            self.bit[pos] -= copies
            pos += pos & -pos
        self.cnt -= copies
        return True

    def all_dump(self):
        """Return ``(value, multiplicity)`` pairs in ascending value order."""
        raw = list(self.bit)
        # Undo the Fenwick aggregation from the top down to recover counts.
        for i in range(self.size - 1, 0, -1):
            j = i + (i & -i)
            if j <= self.size:
                raw[j] -= raw[i]
        return [(self.array[rank], copies) for rank, copies in enumerate(raw[1:]) if copies]


# Solver body: reads a weighted tree (n, then n-1 lines "u v cost", 1-indexed)
# and prints a single integer.
n = int(input())
edges = [list(map(int, input().split())) for _ in range(n - 1)]
INF = 10 ** 18


def _largest_or_zero(multiset):
    """Return the largest value stored in *multiset*, or 0 when it is empty.

    ``kth_largest(0)`` raises IndexError on an empty container, which happens
    as soon as every edge lies on the chosen path.
    """
    return multiset.kth_largest(0) if len(multiset) else 0


# Build the 0-indexed edge list / adjacency list and find the largest cost.
es = []
tree = [[] for i in range(n)]
max_cost = 0
for u, v, cost in edges:
    u -= 1
    v -= 1
    max_cost = max(max_cost, cost)
    es.append((u, v))
    tree[u].append((v, cost))
    tree[v].append((u, cost))

# vals[v] == 1 marks vertices that must lie on the path: first the endpoints
# of every maximum-cost edge ...
vals = [0] * n
for u, v, cost in edges:
    u -= 1
    v -= 1
    if cost == max_cost:
        vals[u] = 1
        vals[v] = 1

unit = 0
merge = lambda a, b: a + b
addnode = lambda val, v: val + vals[v]

# ... then every vertex that has marked vertices in at least two different
# directions (second-largest directional count is non-zero), i.e. vertices
# lying between two marked ones.
sub = rerooting(n, es, unit, merge, addnode)
for v, res in enumerate(sub):
    if len(res) <= 1:
        continue
    res = sorted(res, reverse=True)
    if res[1] != 0:
        vals[v] = 1


# The marked vertices must form a simple path: no marked vertex may have more
# than two marked neighbours.  Collect the two path ends and the cheapest edge
# between marked vertices along the way.
is_path = True
min_cost = INF
ends = []
for v in range(n):
    if vals[v] == 0:
        continue
    cnt = 0
    for nxt_v, cost in tree[v]:
        if vals[nxt_v] == 1:
            min_cost = min(min_cost, cost)
            cnt += 1
    if cnt > 2:
        is_path = False
    if cnt == 1:
        ends.append(v)

if not is_path:
    # The maximum-cost edges cannot all be covered by one path.
    print(max_cost)
    exit()


# Cut the edges between marked vertices; what remains are the components
# hanging off the path.
uf = UnionFind(n)
for u, v in es:
    if vals[u] == 1 and vals[v] == 1:
        continue
    else:
        uf.merge(u, v)

# Multiset of the costs of all edges that are currently NOT on the path.
cands = []
for u, v, cost in edges:
    u -= 1
    v -= 1
    if vals[u] == 1 and vals[v] == 1:
        continue
    cands.append(cost)

st_set = SortedMultiSetBIT(cands)
for val in cands:
    st_set.add(val)


# Extract the two components that contain the path ends, re-indexed compactly
# and rooted at the respective end.
tree1 = None
root1 = None
tree2 = None
root2 = None
for gp in uf.groups():
    if ends[0] in gp or ends[1] in gp:
        if ends[0] in gp:
            rt = ends[0]
        else:
            rt = ends[1]
        mapping = {val: i for i, val in enumerate(sorted(gp))}
        tr = [[] for i in range(len(gp))]
        st = set(gp)
        for u, v, cost in edges:
            u -= 1
            v -= 1
            if u in st and v in st:
                tr[mapping[u]].append((mapping[v], cost))
                tr[mapping[v]].append((mapping[u], cost))
        if root1 is None:
            tree1 = tr
            root1 = mapping[rt]
        else:
            tree2 = tr
            root2 = mapping[rt]

# (vertex, vertex) -> edge cost lookups for both components.
e_val1 = {}
e_val2 = {}
for v in range(len(tree1)):
    for nxt_v, cost in tree1[v]:
        e_val1[v, nxt_v] = cost
        e_val1[nxt_v, v] = cost
for v in range(len(tree2)):
    for nxt_v, cost in tree2[v]:
        e_val2[v, nxt_v] = cost
        e_val2[nxt_v, v] = cost

_, par1 = topological_sorted(tree1, root1)
_, par2 = topological_sorted(tree2, root2)

# cost -> vertices touching an edge of that cost, per component, plus the
# distinct costs in descending order.
order = set()
map1 = {}
map2 = {}
for v in range(len(tree1)):
    for nxt_v, cost in tree1[v]:
        order.add(cost)
        if cost not in map1:
            map1[cost] = []
        map1[cost].append(v)
        map1[cost].append(nxt_v)
for v in range(len(tree2)):
    for nxt_v, cost in tree2[v]:
        order.add(cost)
        if cost not in map2:
            map2[cost] = []
        map2[cost].append(v)
        map2[cost].append(nxt_v)
order = sorted(order, reverse=True)


# Candidate answer for the current path: the larger of (max_cost - cheapest
# edge on the path) and the most expensive edge left outside the path.
ans = max(max_cost - min_cost, _largest_or_zero(st_set))
used1 = [False] * len(tree1)
used1[root1] = True
end1 = root1
used2 = [False] * len(tree2)
used2[root2] = True
end2 = root2


# Greedily pull the most expensive outside edges onto the path, one cost level
# at a time, re-evaluating the candidate answer after each level.
for val in order:
    if val in map1:
        vs = set(map1[val])
        for v in vs:
            if used1[v]:
                continue
            # Walk up towards the root, moving each traversed edge from the
            # outside multiset onto the path.
            tmp_v = v
            while True:
                par_v = par1[tmp_v]
                if not used1[tmp_v]:
                    min_cost = min(min_cost, e_val1[tmp_v, par_v])
                    st_set.remove(e_val1[tmp_v, par_v])
                    used1[tmp_v] = True
                    tmp_v = par_v
                else:
                    break
            # Joining anywhere but the current end would make the path branch,
            # so no further extension is possible.
            if tmp_v != end1:
                print(ans)
                exit()
            end1 = v

    if val in map2:
        vs = set(map2[val])
        for v in vs:
            if used2[v]:
                continue
            tmp_v = v
            while True:
                par_v = par2[tmp_v]
                if not used2[tmp_v]:
                    min_cost = min(min_cost, e_val2[tmp_v, par_v])
                    st_set.remove(e_val2[tmp_v, par_v])
                    used2[tmp_v] = True
                    tmp_v = par_v
                else:
                    break
            if tmp_v != end2:
                print(ans)
                exit()
            end2 = v
    ans = min(ans, max(max_cost - min_cost, _largest_or_zero(st_set)))
print(ans)
0