結果

問題 No.1002 Twotone
ユーザー lam6er
提出日時 2025-04-09 20:56:09
言語 PyPy3
(7.3.15)
結果
MLE  
実行時間 -
コード長 2,998 bytes
コンパイル時間 260 ms
コンパイル使用メモリ 82,360 KB
実行使用メモリ 849,428 KB
最終ジャッジ日時 2025-04-09 20:57:49
合計ジャッジ時間 7,460 ms
ジャッジサーバーID
(参考情報)
judge4 / judge1
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 2
other AC * 1 MLE * 1 -- * 31
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
from sys import stdin
from collections import defaultdict

# Lift the interpreter's recursion ceiling (2**25 == 1 << 25) so that any
# recursive helper in this script is not cut off on long parent chains.
sys.setrecursionlimit(2 ** 25)

class DSU:
    """Union-find (disjoint-set union) over the integer ids 0..n.

    Slot 0 is allocated but unused when callers number vertices from 1.
    Merges use union by size, and ``find`` applies full path compression.
    """

    def __init__(self, n):
        # parent[i] == i marks a root; size[r] is only meaningful for roots r.
        self.parent = [i for i in range(n + 1)]
        self.size = [1 for _ in range(n + 1)]

    def find(self, x):
        """Return the root of ``x``'s set, re-pointing every node on the path at it."""
        parent = self.parent
        root = x
        while parent[root] != root:
            root = parent[root]
        # Second pass: compress the walked path so each node points at the root.
        while x != root:
            nxt = parent[x]
            parent[x] = root
            x = nxt
        return root

    def union(self, x, y):
        """Merge the sets holding ``x`` and ``y``; the larger set's root survives."""
        big, small = self.find(x), self.find(y)
        if big != small:
            if self.size[big] < self.size[small]:
                big, small = small, big
            self.parent[small] = big
            self.size[big] += self.size[small]

def _centroid_of(root, adj, removed, parent, size):
    """Return a centroid of the not-yet-removed component containing ``root``.

    ``parent`` and ``size`` are scratch arrays of length n + 1; their entries
    for every vertex of the component are overwritten.
    """
    order = [root]
    parent[root] = 0  # vertex 0 never occurs: input vertices are 1-based
    i = 0
    while i < len(order):
        v = order[i]
        i += 1
        pv = parent[v]
        for w, _ in adj[v]:
            if w != pv and not removed[w]:
                parent[w] = v
                order.append(w)
    for v in order:
        size[v] = 1
    # BFS order lists each vertex before its descendants, so walking it in
    # reverse finalises every subtree size before adding it to the parent.
    for v in reversed(order):
        if v != root:
            size[parent[v]] += size[v]
    half = len(order) // 2
    c = root
    while True:
        pc = parent[c]
        for w, _ in adj[c]:
            if w != pc and not removed[w] and size[w] > half:
                c = w  # the heavy child holds more than half: move into it
                break
        else:
            return c


def _pairs_with_two_colors(singles, doubles):
    """Count unordered pairs of entries whose color sets unite to exactly two colors.

    ``singles`` maps color a -> number of entries whose color set is {a};
    ``doubles`` maps (a, b) with a < b -> number of entries whose set is {a, b}.
    """
    s1 = sum(singles.values())
    # single/single with different colors
    total = (s1 * s1 - sum(k * k for k in singles.values())) // 2
    for (a, b), k in doubles.items():
        # single {a} or {b} with pair {a, b}
        total += k * (singles.get(a, 0) + singles.get(b, 0))
        # pair/pair with the very same two colors
        total += k * (k - 1) // 2
    return total


def _count_through(c, adj, removed):
    """Count two-colored vertex pairs of c's component whose path passes through c.

    Pairs with c itself as an endpoint are included. For every vertex v the
    color set of the path c -> v is tracked as (a, None) for one color or
    (a, b) with a < b for two; vertices reaching a third color are pruned,
    since every path through them has at least three colors.
    """
    all_single = {}
    all_double = {}
    result = 0
    for w, col in adj[c]:
        if removed[w]:
            continue
        single = {}
        double = {}
        stack = [(w, c, col, None)]
        while stack:
            v, p, a, b = stack.pop()
            if b is None:
                single[a] = single.get(a, 0) + 1
            else:
                double[(a, b)] = double.get((a, b), 0) + 1
            for x, cx in adj[v]:
                if x == p or removed[x]:
                    continue
                if b is None:
                    if cx == a:
                        stack.append((x, v, a, None))
                    elif cx < a:
                        stack.append((x, v, cx, a))
                    else:
                        stack.append((x, v, a, cx))
                elif cx == a or cx == b:
                    stack.append((x, v, a, b))
                # else: a third color appears, so the whole subtree is pruned
        # Pairs inside one child subtree do not go through c; remove them.
        result -= _pairs_with_two_colors(single, double)
        for a, k in single.items():
            all_single[a] = all_single.get(a, 0) + k
        for key, k in double.items():
            all_double[key] = all_double.get(key, 0) + k
    result += _pairs_with_two_colors(all_single, all_double)
    # Pairs (c, v) where the path c -> v already has exactly two colors.
    result += sum(all_double.values())
    return result


def main():
    """Read a colored tree from stdin and print the number of unordered vertex
    pairs whose connecting path uses exactly two distinct edge colors.

    Input: ``N K`` followed by N-1 lines ``u v c`` (1-based vertices).

    Centroid decomposition is used, taking O(N log N) time and O(N) memory.
    The previous per-color-pair DSU approach needed O(K^2 * N) time, and it
    kept one full DSU per color alive, which exhausted memory. Colors that own
    no edge never lie on a path, so they need no special handling.
    """
    data = sys.stdin.read().split()
    n = int(data[0])
    # data[1] is K, the number of colors; the algorithm does not need it.
    adj = [[] for _ in range(n + 1)]
    idx = 2
    for _ in range(n - 1):
        u = int(data[idx])
        v = int(data[idx + 1])
        c = int(data[idx + 2])
        idx += 3
        adj[u].append((v, c))
        adj[v].append((u, c))

    removed = [False] * (n + 1)
    parent = [0] * (n + 1)
    size = [0] * (n + 1)
    answer = 0
    # Iterative work list of component roots; avoids deep Python recursion.
    pending = [1] if n >= 1 else []
    while pending:
        root = pending.pop()
        centroid = _centroid_of(root, adj, removed, parent, size)
        answer += _count_through(centroid, adj, removed)
        removed[centroid] = True
        for w, _ in adj[centroid]:
            if not removed[w]:
                pending.append(w)

    print(answer)

if __name__ == "__main__":
    # Solve only when run as a script, not when imported.
    main()
0