結果

問題 No.1002 Twotone
ユーザー lam6er
提出日時 2025-03-20 20:54:53
言語 PyPy3
(7.3.15)
結果
TLE  
実行時間 -
コード長 3,937 bytes
コンパイル時間 170 ms
コンパイル使用メモリ 82,252 KB
実行使用メモリ 389,776 KB
最終ジャッジ日時 2025-03-20 20:55:47
合計ジャッジ時間 20,784 ms
ジャッジサーバーID
(参考情報)
judge3 / judge5
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 2
other AC * 9 TLE * 1 -- * 23
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
from sys import stdin
from collections import defaultdict

sys.setrecursionlimit(1 << 25)

def main():
    input = sys.stdin.read().split()
    idx = 0
    N, K = int(input[idx]), int(input[idx+1])
    idx +=2

    edges = [[] for _ in range(N+1)]
    color_edges = defaultdict(list)
    for i in range(N-1):
        u = int(input[idx])
        v = int(input[idx+1])
        c = int(input[idx+2])
        idx +=3
        edges[u].append((v, c, i))
        edges[v].append((u, c, i))
        color_edges[c].append((u, v))

    # Compute sc for each color
    sc = dict()
    for c in color_edges:
        parent = {}
        size = {}

        def find(u):
            if parent[u] != u:
                parent[u] = find(parent[u])
            return parent[u]

        def union(u, v):
            u_root = find(u)
            v_root = find(v)
            if u_root == v_root:
                return
            if size[u_root] < size[v_root]:
                u_root, v_root = v_root, u_root
            parent[v_root] = u_root
            size[u_root] += size[v_root]

        for u, v in color_edges[c]:
            if u not in parent:
                parent[u] = u
                size[u] = 1
            if v not in parent:
                parent[v] = v
                size[v] = 1
            union(u, v)

        components = defaultdict(int)
        for u in parent:
            root = find(u)
            components[root] = size[root]

        total = 0
        for s in components.values():
            total += s * (s - 1) // 2
        sc[c] = total

    # Collect all (c, d) pairs from edges around each node
    pairs = set()
    node_color = defaultdict(set)
    for u in range(1, N+1):
        colors = set()
        for _, c, _ in edges[u]:
            colors.add(c)
        colors = list(colors)
        for i in range(len(colors)):
            for j in range(i+1, len(colors)):
                c1 = colors[i]
                c2 = colors[j]
                if c1 < c2:
                    pairs.add((c1, c2))
                else:
                    pairs.add((c2, c1))
        for c in colors:
            if len(colors) >= 1:
                for c2 in colors:
                    if c != c2:
                        if c < c2:
                            pairs.add((c, c2))
                        else:
                            pairs.add((c2, c))

    valid = 0

    # Process each unique pair (c, d) in pairs, ensuring c < d
    processed = set()
    for (c1, c2) in pairs:
        if c1 == c2:
            continue
        a, b = min(c1, c2), max(c1, c2)
        if (a, b) in processed:
            continue
        processed.add((a, b))

        combined_edges = []
        if a in color_edges:
            combined_edges.extend(color_edges[a])
        if b in color_edges:
            combined_edges.extend(color_edges[b])

        parent = {}
        size = {}

        def find(u):
            if parent[u] != u:
                parent[u] = find(parent[u])
            return parent[u]

        def union(u, v):
            u_root = find(u)
            v_root = find(v)
            if u_root == v_root:
                return
            if size[u_root] < size[v_root]:
                u_root, v_root = v_root, u_root
            parent[v_root] = u_root
            size[u_root] += size[v_root]

        for u, v in combined_edges:
            if u not in parent:
                parent[u] = u
                size[u] = 1
            if v not in parent:
                parent[v] = v
                size[v] = 1
            union(u, v)

        components = defaultdict(int)
        for node in parent:
            root = find(node)
            components[root] = size.get(root, 1)

        total_ab = sum(m * (m-1) // 2 for m in components.values())
        sa = sc.get(a, 0)
        sb = sc.get(b, 0)
        valid += (total_ab - sa - sb)

    print(valid)

if __name__ == '__main__':
    main()
0