結果
| 問題 | No.1002 Twotone |
|---|---|
| コンテスト | |
| ユーザー | qwewe |
| 提出日時 | 2025-05-14 13:22:07 |
| 言語 | PyPy3 (7.3.15) |
| 結果 | TLE |
| 実行時間 | - |
| コード長 | 5,785 bytes |
| コンパイル時間 | 196 ms |
| コンパイル使用メモリ | 82,144 KB |
| 実行使用メモリ | 155,852 KB |
| 最終ジャッジ日時 | 2025-05-14 13:24:59 |
| 合計ジャッジ時間 | 34,847 ms |
| ジャッジサーバーID (参考情報) | judge4 / judge2 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 2 |
| other | AC * 7 TLE * 4 -- * 22 |
ソースコード
import sys
def _find_centroid(adj, removed, parent, size, root):
    """Return a centroid of the component containing ``root``.

    Only vertices with ``removed[v] == 0`` belong to the component.
    ``parent`` and ``size`` are scratch arrays; their entries are overwritten
    for every vertex of the component. All work is iterative (BFS plus a
    reverse sweep), so arbitrarily deep trees are handled without recursion.
    """
    parent[root] = -1
    order = [root]
    head = 0
    while head < len(order):
        u = order[head]
        head += 1
        pu = parent[u]
        for v, _ in adj[u]:
            if v != pu and not removed[v]:
                parent[v] = u
                order.append(v)

    for u in order:
        size[u] = 1
    # In BFS order every child comes after its parent, so sweeping backwards
    # finishes each subtree size before it is added to the parent.
    for i in range(len(order) - 1, 0, -1):
        u = order[i]
        size[parent[u]] += size[u]

    total = len(order)
    centroid = root
    while True:
        pc = parent[centroid]
        # Step into the (unique) child holding more than half the component.
        for v, _ in adj[centroid]:
            if v != pc and not removed[v] and size[v] * 2 > total:
                centroid = v
                break
        else:
            return centroid


def _count_pairs_through(adj, removed, centroid, base):
    """Count vertex pairs whose path passes through ``centroid`` and uses
    exactly two distinct colours; mark ``centroid`` as removed.

    Colours must be ints in ``1..base-1``. While walking away from the
    centroid the colour set of the path is kept as ``(a, b)`` with ``b == 0``
    meaning "only colour ``a`` so far"; a two-colour set ``{a, b}`` with
    ``a < b`` is stored as the int key ``a * base + b``. Walks stop as soon as
    a third colour appears, since no extension can then qualify.

    Each child subtree is matched only against the centroid and the subtrees
    already merged, so pairs lying inside one subtree are never counted.
    """
    removed[centroid] = 1
    one_count = {}     # colour a -> #vertices whose centroid path is exactly {a}
    two_count = {}     # key(a, b) -> #vertices whose centroid path is exactly {a, b}
    two_touching = {}  # colour a -> #vertices with a two-colour path containing a
    ones_total = 0
    result = 0

    for start, start_colour in adj[centroid]:
        if removed[start]:
            continue
        ones = []
        twos = []
        stack = [(start, centroid, start_colour, 0)]
        while stack:
            x, px, a, b = stack.pop()
            if b:
                twos.append(a * base + b)
            else:
                ones.append(a)
            for y, c in adj[x]:
                if y == px or removed[y]:
                    continue
                if c == a or c == b:
                    stack.append((y, x, a, b))
                elif not b:
                    stack.append((y, x, a, c) if a < c else (y, x, c, a))
                # else: a third colour would appear, so prune this branch.

        get_one = one_count.get
        # {a} pairs with earlier singles of another colour and with earlier
        # two-colour sets containing a (the centroid alone gives only {a}).
        for a in ones:
            result += ones_total - get_one(a, 0) + two_touching.get(a, 0)
        # {a, b} pairs with the centroid (+1), with earlier singles {a} or {b},
        # and with earlier vertices having the identical set {a, b}.
        for key in twos:
            a, b = divmod(key, base)
            result += 1 + get_one(a, 0) + get_one(b, 0) + two_count.get(key, 0)

        # Merge this subtree so later subtrees can pair with it.
        for a in ones:
            one_count[a] = get_one(a, 0) + 1
        ones_total += len(ones)
        for key in twos:
            a, b = divmod(key, base)
            two_count[key] = two_count.get(key, 0) + 1
            two_touching[a] = two_touching.get(a, 0) + 1
            two_touching[b] = two_touching.get(b, 0) + 1
    return result


def solve():
    """Read a coloured tree from stdin and print the number of two-colour pairs.

    Input: ``N K`` followed by ``N - 1`` lines ``u v c`` (1-indexed endpoints,
    edge colour ``c``). Output: the number of unordered vertex pairs whose
    connecting path uses exactly two distinct colours.

    Centroid decomposition with iterative traversals: O(N log N) dict/list
    operations and no deep recursion.
    """
    data = sys.stdin.buffer.read().split()
    n = int(data[0])
    # data[1] is K; colours are compressed to 1..m below, so K is not needed.
    adj = [[] for _ in range(n)]
    colour_ids = {}
    pos = 2
    for _ in range(n - 1):
        u = int(data[pos]) - 1
        v = int(data[pos + 1]) - 1
        c = colour_ids.setdefault(int(data[pos + 2]), len(colour_ids) + 1)
        pos += 3
        adj[u].append((v, c))
        adj[v].append((u, c))

    base = len(colour_ids) + 1  # every compressed colour is < base
    removed = bytearray(n)
    parent = [-1] * n
    size = [0] * n
    answer = 0
    pending = [0] if n > 0 else []
    while pending:
        centroid = _find_centroid(adj, removed, parent, size, pending.pop())
        answer += _count_pairs_through(adj, removed, centroid, base)
        pending.extend(v for v, _ in adj[centroid] if not removed[v])
    print(answer)


if __name__ == "__main__":
    solve()
qwewe