結果

問題 No.1002 Twotone
ユーザー qwewe
提出日時 2025-05-14 13:22:07
言語 PyPy3 (7.3.15)
結果
TLE  
実行時間 -
コード長 5,785 bytes
コンパイル時間 196 ms
コンパイル使用メモリ 82,144 KB
実行使用メモリ 155,852 KB
最終ジャッジ日時 2025-05-14 13:24:59
合計ジャッジ時間 34,847 ms
ジャッジサーバーID (参考情報): judge4 / judge2
このコードへのチャレンジ (要ログイン)
ファイルパターン 結果
sample AC * 2
other AC * 7 TLE * 4 -- * 22
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys

# Safety margin in case a tree walk recurses: on a path-shaped tree a
# recursive DFS can be as deep as the number of vertices.
sys.setrecursionlimit(400_000)

def _read_tree(tokens):
    """Parse ``N K`` followed by ``N - 1`` edge triples ``u v c``.

    Vertices are converted to 0-based indices.  Colours are remapped to dense
    ids ``0 .. num_colors - 1`` so they can index plain lists; ``K`` itself is
    therefore not needed and is skipped.

    Returns:
        ``(n, adj, num_colors)`` where ``adj[u]`` is a list of
        ``(neighbour, color_id)`` tuples.
    """
    n = int(tokens[0])
    adj = [[] for _ in range(n)]
    color_id = {}
    pos = 2  # tokens[1] is K, unused after remapping
    for _ in range(n - 1):
        u = int(tokens[pos]) - 1
        v = int(tokens[pos + 1]) - 1
        raw = int(tokens[pos + 2])
        pos += 3
        c = color_id.get(raw)
        if c is None:
            c = color_id[raw] = len(color_id)
        adj[u].append((v, c))
        adj[v].append((u, c))
    return n, adj, len(color_id)


def _count_two_color_paths(n, adj, num_colors):
    """Count unordered vertex pairs whose path uses exactly two distinct colours.

    Iterative centroid decomposition (no recursion, so path-shaped trees are
    safe).  For each centroid ``c``, every vertex ``x`` in a neighbouring
    subtree gets the colour set of the path ``c -> x`` encoded as two ints
    ``(first, second)``: ``second == -1`` means a single colour, otherwise
    ``first < second``.  Vertices whose set would exceed two colours are
    pruned together with everything below them.  Subtrees are processed one
    at a time, pairing each new vertex with the centroid and with all earlier
    subtrees, so pairs inside one subtree are never counted here (they are
    handled when that subtree is decomposed).  Each vertex is visited
    O(log n) times because components at least halve per level.

    Args:
        n: number of vertices.
        adj: adjacency lists of ``(neighbour, color_id)`` with dense colour ids.
        num_colors: number of distinct colour ids.

    Returns:
        The number of qualifying unordered pairs.
    """
    if n <= 1:
        return 0

    removed = [False] * n
    parent = [-1] * n
    size = [0] * n
    first = [-1] * n   # smaller (or only) colour on path centroid -> u
    second = [-1] * n  # larger colour, or -1 when the path is monochrome

    # Tallies over the centroid's already-processed subtrees.
    single = [0] * num_colors     # vertices whose path colour set is {a}
    pair_with = [0] * num_colors  # vertices whose two-colour set contains a
    pair_count = {}               # a * num_colors + b (a < b) -> vertex count

    total = 0
    pending = [0]  # one entry vertex per not-yet-decomposed component
    while pending:
        root = pending.pop()

        # 1) BFS over the component; reversed BFS order yields subtree sizes.
        order = [root]
        parent[root] = -1
        size[root] = 1
        i = 0
        while i < len(order):
            u = order[i]
            i += 1
            pu = parent[u]
            for v, _ in adj[u]:
                if v != pu and not removed[v]:
                    parent[v] = u
                    size[v] = 1
                    order.append(v)
        for idx in range(len(order) - 1, 0, -1):
            u = order[idx]
            size[parent[u]] += size[u]
        comp = len(order)

        # 2) Walk toward a child holding more than half the component.
        c = root
        while True:
            pc = parent[c]
            for v, _ in adj[c]:
                if v != pc and not removed[v] and size[v] * 2 > comp:
                    c = v
                    break
            else:
                break
        removed[c] = True

        # 3) Count pairs whose path passes through the centroid c.
        singles_total = 0
        touched = []  # colour ids to reset in single / pair_with afterwards
        for v0, e0 in adj[c]:
            if removed[v0]:
                continue
            parent[v0] = c
            first[v0] = e0
            second[v0] = -1
            nodes = [v0]
            j = 0
            while j < len(nodes):
                u = nodes[j]
                j += 1
                a = first[u]
                b = second[u]
                pu = parent[u]
                for w, e in adj[u]:
                    if w == pu or removed[w]:
                        continue
                    if e == a or e == b:  # colour already on the path
                        na, nb = a, b
                    elif b < 0:  # second colour joins the path
                        na, nb = (e, a) if e < a else (a, e)
                    else:
                        continue  # third colour: nothing deeper can qualify
                    parent[w] = u
                    first[w] = na
                    second[w] = nb
                    nodes.append(w)

            # Pair this subtree with the centroid and earlier subtrees.
            for u in nodes:
                a = first[u]
                b = second[u]
                if b < 0:
                    # {a} + {x} (x != a), or {a} + a pair containing a.
                    total += singles_total - single[a] + pair_with[a]
                else:
                    # {a,b} + centroid, + {a}, + {b}, + the same pair.
                    total += (1 + single[a] + single[b]
                              + pair_count.get(a * num_colors + b, 0))

            # Merge this subtree into the running tallies.
            for u in nodes:
                a = first[u]
                b = second[u]
                touched.append(a)
                if b < 0:
                    single[a] += 1
                    singles_total += 1
                else:
                    touched.append(b)
                    pair_with[a] += 1
                    pair_with[b] += 1
                    key = a * num_colors + b
                    pair_count[key] = pair_count.get(key, 0) + 1

        # Reset tallies in O(touched) rather than O(num_colors).
        for col in touched:
            single[col] = 0
            pair_with[col] = 0
        pair_count.clear()

        for v, _ in adj[c]:
            if not removed[v]:
                pending.append(v)
    return total


def solve():
    """Read ``N K`` and ``N - 1`` coloured edges from stdin; print the answer.

    The answer is the number of unordered vertex pairs whose connecting path
    uses exactly two distinct edge colours.
    """
    n, adj, num_colors = _read_tree(sys.stdin.read().split())
    print(_count_two_color_paths(n, adj, num_colors))

# Run only when executed as a script, so importing this module has no side
# effects.  (The stray scraped "0" line after the call was a no-op and is dropped.)
if __name__ == "__main__":
    solve()