結果

問題 No.1833 Subway Planning
ユーザー neterukunneterukun
提出日時 2022-02-04 23:08:40
言語 PyPy3 (7.3.15)
結果
RE  
実行時間 -
コード長 3,635 bytes
コンパイル時間 159 ms
コンパイル使用メモリ 82,432 KB
実行使用メモリ 237,444 KB
最終ジャッジ日時 2024-06-11 12:43:06
合計ジャッジ時間 20,163 ms
ジャッジサーバーID (参考情報) judge1 / judge2
このコードへのチャレンジ (要ログイン)

テストケース

テストケース表示
入力 結果 実行時間 実行使用メモリ
testcase_00 RE -
testcase_01 RE -
testcase_02 RE -
testcase_03 RE -
testcase_04 RE -
testcase_05 RE -
testcase_06 RE -
testcase_07 RE -
testcase_08 RE -
testcase_09 RE -
testcase_10 RE -
testcase_11 AC 1,005 ms 233,488 KB
testcase_12 AC 983 ms 235,092 KB
testcase_13 RE -
testcase_14 RE -
testcase_15 RE -
testcase_16 AC 1,139 ms 214,724 KB
testcase_17 RE -
testcase_18 RE -
testcase_19 AC 1,237 ms 214,716 KB
testcase_20 RE -
testcase_21 RE -
testcase_22 RE -
testcase_23 RE -
testcase_24 RE -
testcase_25 RE -
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
input = sys.stdin.buffer.readline


class UnionFind:
    """Disjoint-set forest over elements 0 .. n-1 (union by size + path compression).

    ``parent[x]`` is the parent of ``x``; when ``x`` is a representative the
    entry instead stores minus the size of its set.
    """

    def __init__(self, n):
        self.n = n
        self.cnt = n  # number of disjoint sets currently present
        self.parent = [-1] * n  # every element starts as a singleton root

    def root(self, x):
        """Return the representative of ``x`` and point every node on the way straight at it."""
        leader = x
        while self.parent[leader] >= 0:
            leader = self.parent[leader]
        # Second pass: re-hang each node on the walked path directly under the leader.
        while x != leader:
            nxt = self.parent[x]
            self.parent[x] = leader
            x = nxt
        return leader

    def merge(self, x, y):
        """Unite the sets containing ``x`` and ``y``; return True iff they were distinct."""
        rx = self.root(x)
        ry = self.root(y)
        if rx == ry:
            return False
        # The root with the more negative entry (larger set) survives;
        # on a tie the root of ``x`` stays the representative.
        if self.parent[ry] < self.parent[rx]:
            rx, ry = ry, rx
        self.parent[rx] += self.parent[ry]
        self.parent[ry] = rx
        self.cnt -= 1
        return True

    def same(self, x, y):
        """Return True when ``x`` and ``y`` belong to the same set."""
        rx = self.root(x)
        return rx == self.root(y)

    def size(self, x):
        """Return the number of elements in the set containing ``x``."""
        leader = self.root(x)
        return -self.parent[leader]

    def count(self):
        """Return the current number of disjoint sets."""
        return self.cnt

    def groups(self):
        """Return all sets as lists, ordered by representative index, members ascending."""
        buckets = [[] for _ in range(self.n)]
        for member in range(self.n):
            buckets[self.root(member)].append(member)
        return list(filter(None, buckets))


def rerooting(n, edges, unit, merge, addnode, *, return_noderes=False):
    """All-directions ("rerooting") tree DP over an undirected forest.

    For every vertex ``v`` and each neighbour ``w = tree[v][i]`` (neighbours
    are listed in the order their edges appear in ``edges``), compute the DP
    value of the branch hanging off ``v`` through ``w``, i.e. of the subtree
    rooted at ``w`` when its component is rooted at ``v``.  A subtree value is
    obtained by folding the values of a vertex's child branches with
    ``merge`` (starting from ``unit``) and then applying
    ``addnode(folded, vertex)``.

    ``merge`` must be associative with ``unit`` as its identity: the
    rerooting pass combines a prefix fold and a suffix fold to exclude one
    neighbour at a time.

    Args:
        n: number of vertices, labelled 0 .. n-1.
        edges: iterable of 0-indexed ``(u, v)`` pairs; assumed to form a forest.
        unit: identity element for ``merge``.
        merge: ``merge(a, b)`` combines two branch values.
        addnode: ``addnode(folded, v)`` turns the fold of ``v``'s child branches
            into the value of the subtree rooted at ``v``.
        return_noderes: when True, also return the per-vertex whole-component values.

    Returns:
        ``sub`` where ``sub[v][i]`` is the value of the branch through
        ``tree[v][i]``.  If ``return_noderes`` is True, returns
        ``(sub, noderes)`` with ``noderes[v] = addnode(fold of sub[v], v)``,
        the value of ``v``'s whole component rooted at ``v``.
    """
    tree = [[] for _ in range(n)]
    # idxs[v][i] is the position of v inside tree[tree[v][i]], so the
    # opposite direction of an edge can be addressed in O(1).
    idxs = [[] for _ in range(n)]
    for u, v in edges:
        idxs[u].append(len(tree[v]))
        idxs[v].append(len(tree[u]))
        tree[u].append(v)
        tree[v].append(u)
    sub = [[unit] * len(adj) for adj in tree]
    noderes = [unit] * n

    # Iterative DFS from every unreached vertex; tp_order lists parents
    # before their children.
    tp_order = []
    par = [-1] * n
    for root in range(n):
        # Vertices reached from an earlier root already have a parent.
        if par[root] != -1:
            continue
        stack = [root]
        while stack:
            v = stack.pop()
            tp_order.append(v)
            for nxt_v in tree[v]:
                if nxt_v == par[v]:
                    continue
                par[nxt_v] = v
                stack.append(nxt_v)

    # Bottom-up pass: push each vertex's subtree value into its parent's slot.
    for v in reversed(tp_order):
        p = par[v]
        if p == -1:
            continue  # component root: nothing above it to report to
        res = unit
        par_idx = -1
        for idx, nxt_v in enumerate(tree[v]):
            if nxt_v == p:
                par_idx = idx
                continue
            res = merge(res, sub[v][idx])
        sub[p][idxs[v][par_idx]] = addnode(res, v)

    # Top-down pass: once all of sub[v] is known (its parent slot was filled
    # when the parent was processed), hand each neighbour the value of v's
    # side with that neighbour's own branch excluded.
    for v in tp_order:
        adj = tree[v]
        deg = len(adj)
        # acc_back[i] = fold of sub[v][i+1:]
        acc_back = [unit] * deg
        for i in range(deg - 1, 0, -1):
            acc_back[i - 1] = merge(sub[v][i], acc_back[i])
        acc_front = unit  # fold of sub[v][:idx]
        for idx, nxt_v in enumerate(adj):
            sub[nxt_v][idxs[v][idx]] = addnode(merge(acc_front, acc_back[idx]), v)
            acc_front = merge(acc_front, sub[v][idx])
        noderes[v] = addnode(acc_front, v)

    if return_noderes:
        return sub, noderes
    return sub


n = int(input())
edges = [list(map(int, input().split())) for _ in range(n - 1)]

# es: 0-indexed (u, v) pairs for rerooting(); tree: adjacency list that also
# keeps each edge's cost; max_cost: the largest edge cost (0 when n == 1).
es = []
tree = [[] for _ in range(n)]
max_cost = 0
for u, v, cost in edges:
    u -= 1
    v -= 1
    max_cost = max(max_cost, cost)
    es.append((u, v))
    tree[u].append((v, cost))
    tree[v].append((u, cost))

# Mark both endpoints of every edge whose cost equals max_cost.
vals = [0] * n
for u, v, cost in edges:
    if cost == max_cost:
        vals[u - 1] = 1
        vals[v - 1] = 1

unit = 0


def merge(a, b):
    """Combine the marked-vertex counts of two branches."""
    return a + b


def addnode(val, v):
    """Add vertex v's own mark to the folded count of its child branches."""
    return val + vals[v]


# sub[v][i] = number of marked vertices in the branch hanging off v through
# its i-th neighbour (a sum of 0/1 marks, so never negative).
sub = rerooting(n, es, unit, merge, addnode)

# A vertex lies on the minimal subtree connecting all marked vertices iff it
# is marked itself or at least two of its branches contain a marked vertex.
for v, branch_counts in enumerate(sub):
    if sum(1 for c in branch_counts if c) >= 2:
        vals[v] = 1

# The marked vertices form a connected subtree; it is a simple path iff no
# marked vertex has more than two marked neighbours.  ends collects the
# marked vertices with exactly one marked neighbour (the path's endpoints).
is_path = True
ends = []
for v in range(n):
    if not vals[v]:
        continue
    cnt = sum(1 for nxt_v, _ in tree[v] if vals[nxt_v] >= 1)
    if cnt > 2:
        is_path = False
    if cnt == 1:
        ends.append(v)

if not is_path:
    # NOTE(review): presumably the heaviest edges cannot all be covered by a
    # single path here, so the answer stays max_cost — confirm against the
    # problem statement.
    print(max_cost)
    sys.exit()

# TODO: compute the answer when every heaviest edge lies on one path (its
# endpoints are in ``ends``).  Fail loudly instead of printing a wrong answer
# or crashing on an undefined name.
raise NotImplementedError(
    "answer for the case where all heaviest edges lie on one path is not implemented"
)