結果

問題 No.1002 Twotone
ユーザー gew1fw
提出日時 2025-06-12 19:52:00
言語 PyPy3
(7.3.15)
結果
TLE  
実行時間 -
コード長 3,722 bytes
コンパイル時間 168 ms
コンパイル使用メモリ 81,956 KB
実行使用メモリ 347,956 KB
最終ジャッジ日時 2025-06-12 19:52:35
合計ジャッジ時間 7,809 ms
ジャッジサーバーID
(参考情報)
judge2 / judge1
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 2
other AC * 1 TLE * 1 -- * 31
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
from collections import deque

def _connected_pairs(edge_list):
    """Return the number of unordered vertex pairs connected by ``edge_list``.

    ``edge_list`` is an iterable of ``(u, v)`` endpoints.  A union-find keyed
    by vertex (dicts, so only touched vertices cost anything) merges the
    endpoints edge by edge.  Whenever two components of sizes ``a`` and ``b``
    are merged, exactly ``a * b`` new vertex pairs become connected, so the
    running sum of those products equals the sum of ``size * (size - 1) // 2``
    over the final components.
    """
    parent = {}
    size = {}

    def find(x):
        # Locate the root, then compress the path that was walked.
        root = x
        while parent[root] != root:
            root = parent[root]
        while parent[x] != root:
            nxt = parent[x]
            parent[x] = root
            x = nxt
        return root

    pairs = 0
    for u, v in edge_list:
        for x in (u, v):
            if x not in parent:
                parent[x] = x
                size[x] = 1
        ru = find(u)
        rv = find(v)
        if ru == rv:
            # Cannot happen for edges taken from a tree; kept as a guard.
            continue
        if size[ru] < size[rv]:
            ru, rv = rv, ru
        parent[rv] = ru
        pairs += size[ru] * size[rv]
        size[ru] += size[rv]
    return pairs


def main():
    """Count vertex pairs whose connecting path uses exactly two edge colours.

    Input (stdin): a line ``N K`` followed by ``N - 1`` lines ``u v c``
    describing a coloured edge between vertices ``u`` and ``v`` (vertices are
    indexed ``1..N``).  ``K`` is read but not otherwise needed.

    Output (stdout): a single integer.

    Method: for every pair of distinct colours ``(c, d)`` that meet at some
    vertex, the vertex pairs joined using only colours ``c`` and ``d`` are
    counted, and those joined by ``c`` alone or ``d`` alone are subtracted.
    The single-colour counts depend on one colour only, so they are computed
    once per colour instead of once per colour pair and per component.

    The work per colour pair is proportional to the number of edges carrying
    those two colours, but the number of colour pairs itself can still be
    large (e.g. a star with all-distinct colours), so this is not an
    asymptotically optimal solution.
    """
    read = sys.stdin.readline
    n, _k = map(int, read().split())

    edges_by_color = {}
    colors_at = [set() for _ in range(n + 1)]
    for _ in range(n - 1):
        u, v, c = map(int, read().split())
        edges_by_color.setdefault(c, []).append((u, v))
        colors_at[u].add(c)
        colors_at[v].add(c)

    # A path with exactly two colours must switch colour at some vertex, so
    # only colour pairs that share a vertex can contribute.  Same-colour
    # "pairs" always contribute zero and are therefore never generated.
    color_pairs = set()
    for colors in colors_at:
        if len(colors) < 2:
            continue
        ordered = sorted(colors)
        for i, c in enumerate(ordered):
            for d in ordered[i + 1:]:
                color_pairs.add((c, d))

    # Pairs joined by a single colour, computed once per colour.
    mono = {c: _connected_pairs(es) for c, es in edges_by_color.items()}

    result = 0
    for c, d in color_pairs:
        both = _connected_pairs(edges_by_color[c] + edges_by_color[d])
        # Pairs joined by c only and by d only are disjoint subsets of the
        # pairs joined by c-or-d, so the difference is never negative.
        result += both - mono[c] - mono[d]

    print(result)

# Run the solver only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
0