結果
| 問題 |
No.1002 Twotone
|
| コンテスト | |
| ユーザー |
lam6er
|
| 提出日時 | 2025-03-20 19:03:23 |
| 言語 | PyPy3 (7.3.15) |
| 結果 |
TLE
|
| 実行時間 | - |
| コード長 | 3,937 bytes |
| コンパイル時間 | 258 ms |
| コンパイル使用メモリ | 82,184 KB |
| 実行使用メモリ | 283,680 KB |
| 最終ジャッジ日時 | 2025-03-20 19:04:39 |
| 合計ジャッジ時間 | 19,903 ms |
|
ジャッジサーバーID (参考情報) |
judge3 / judge5 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 2 |
| other | AC * 9 TLE * 1 -- * 23 |
ソースコード
import sys
from sys import stdin
from collections import defaultdict
sys.setrecursionlimit(1 << 25)
def main():
input = sys.stdin.read().split()
idx = 0
N, K = int(input[idx]), int(input[idx+1])
idx +=2
edges = [[] for _ in range(N+1)]
color_edges = defaultdict(list)
for i in range(N-1):
u = int(input[idx])
v = int(input[idx+1])
c = int(input[idx+2])
idx +=3
edges[u].append((v, c, i))
edges[v].append((u, c, i))
color_edges[c].append((u, v))
# Compute sc for each color
sc = dict()
for c in color_edges:
parent = {}
size = {}
def find(u):
if parent[u] != u:
parent[u] = find(parent[u])
return parent[u]
def union(u, v):
u_root = find(u)
v_root = find(v)
if u_root == v_root:
return
if size[u_root] < size[v_root]:
u_root, v_root = v_root, u_root
parent[v_root] = u_root
size[u_root] += size[v_root]
for u, v in color_edges[c]:
if u not in parent:
parent[u] = u
size[u] = 1
if v not in parent:
parent[v] = v
size[v] = 1
union(u, v)
components = defaultdict(int)
for u in parent:
root = find(u)
components[root] = size[root]
total = 0
for s in components.values():
total += s * (s - 1) // 2
sc[c] = total
# Collect all (c, d) pairs from edges around each node
pairs = set()
node_color = defaultdict(set)
for u in range(1, N+1):
colors = set()
for _, c, _ in edges[u]:
colors.add(c)
colors = list(colors)
for i in range(len(colors)):
for j in range(i+1, len(colors)):
c1 = colors[i]
c2 = colors[j]
if c1 < c2:
pairs.add((c1, c2))
else:
pairs.add((c2, c1))
for c in colors:
if len(colors) >= 1:
for c2 in colors:
if c != c2:
if c < c2:
pairs.add((c, c2))
else:
pairs.add((c2, c))
valid = 0
# Process each unique pair (c, d) in pairs, ensuring c < d
processed = set()
for (c1, c2) in pairs:
if c1 == c2:
continue
a, b = min(c1, c2), max(c1, c2)
if (a, b) in processed:
continue
processed.add((a, b))
combined_edges = []
if a in color_edges:
combined_edges.extend(color_edges[a])
if b in color_edges:
combined_edges.extend(color_edges[b])
parent = {}
size = {}
def find(u):
if parent[u] != u:
parent[u] = find(parent[u])
return parent[u]
def union(u, v):
u_root = find(u)
v_root = find(v)
if u_root == v_root:
return
if size[u_root] < size[v_root]:
u_root, v_root = v_root, u_root
parent[v_root] = u_root
size[u_root] += size[v_root]
for u, v in combined_edges:
if u not in parent:
parent[u] = u
size[u] = 1
if v not in parent:
parent[v] = v
size[v] = 1
union(u, v)
components = defaultdict(int)
for node in parent:
root = find(node)
components[root] = size.get(root, 1)
total_ab = sum(m * (m-1) // 2 for m in components.values())
sa = sc.get(a, 0)
sb = sc.get(b, 0)
valid += (total_ab - sa - sb)
print(valid)
if __name__ == '__main__':
main()
lam6er