結果

問題 No.1718 Random Squirrel
ユーザー lam6er
提出日時 2025-03-31 17:45:10
言語 PyPy3
(7.3.15)
結果
AC  
実行時間 771 ms / 2,000 ms
コード長 4,339 bytes
コンパイル時間 365 ms
コンパイル使用メモリ 82,720 KB
実行使用メモリ 141,332 KB
最終ジャッジ日時 2025-03-31 17:46:14
合計ジャッジ時間 12,167 ms
ジャッジサーバーID
(参考情報)
judge1 / judge3
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 2
other AC * 31
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
from collections import deque

def _bfs(adj, start, n, allowed=None):
    """Breadth-first search from ``start`` over a 1-indexed adjacency list.

    Args:
        adj: ``adj[v]`` lists the neighbours of vertex ``v`` (vertices 1..n).
        start: source vertex.
        n: number of vertices.
        allowed: optional mask; when given, only vertices ``y`` with
            ``allowed[y]`` true are entered (``start`` must be allowed).

    Returns:
        ``(dist, parent, order)``:
        - ``dist[y]`` is the hop count from ``start``, or -1 if not reached.
        - ``parent[y]`` is the vertex ``y`` was discovered from (0 for ``start``
          and for unreached vertices).
        - ``order`` lists the reached vertices in non-decreasing distance order
          (``start`` first).
    """
    dist = [-1] * (n + 1)
    parent = [0] * (n + 1)
    dist[start] = 0
    order = [start]
    queue = deque([start])
    while queue:
        x = queue.popleft()
        step = dist[x] + 1
        for y in adj[x]:
            if dist[y] == -1 and (allowed is None or allowed[y]):
                dist[y] = step
                parent[y] = x
                order.append(y)
                queue.append(y)
    return dist, parent, order


def main():
    """Read a tree and a target set from stdin and print one answer per vertex.

    Input tokens: ``N K``, then ``N-1`` edges ``u v``, then ``K`` target
    vertex ids.

    Let T be the minimal subtree spanning the targets. For every start
    vertex ``x = 1..N`` this prints
    ``dist(x, T) + 2*|E(T)| - max_{t in T} dist_T(y, t)``, where ``y`` is the
    vertex at which ``x`` enters T. That value is the length of the
    shortest walk from ``x`` that visits every target: walk to ``y``, then
    cover T, leaving only the leg to the farthest T-vertex untraversed twice.
    With no targets every answer is 0.

    Repeated target ids are tolerated; they are treated as a set.
    """
    tokens = sys.stdin.read().split()
    n, k = int(tokens[0]), int(tokens[1])

    adj = [[] for _ in range(n + 1)]
    pos = 2
    for _ in range(n - 1):
        a, b = int(tokens[pos]), int(tokens[pos + 1])
        adj[a].append(b)
        adj[b].append(a)
        pos += 2
    targets = [int(t) for t in tokens[pos:pos + k]]

    if not targets:
        # Nothing to visit: every start needs a walk of length 0.
        print('\n'.join('0' for _ in range(n)))
        return

    is_target = [False] * (n + 1)
    for t in targets:
        is_target[t] = True

    # Root the tree at a target. A vertex then lies on T exactly when its
    # subtree contains a target (it sits on the root-to-target path). This
    # does not depend on K, so duplicate target ids cannot inflate T.
    root = targets[0]
    depth, parent, order = _bfs(adj, root, n)
    in_tree = is_target[:]
    for v in reversed(order):  # children before parents
        if in_tree[v] and v != root:
            in_tree[parent[v]] = True

    tree_vertices = [v for v in order if in_tree[v]]
    tree_edges = len(tree_vertices) - 1  # T is itself a tree

    # T contains every ancestor of its vertices, and a subtree without
    # targets holds no T-vertex. So an outside vertex reaches T first at its
    # lowest ancestor in T. ``order`` is top-down, so parents are done first.
    entry = list(range(n + 1))  # entry[v] == v for v in T
    hops = [0] * (n + 1)
    for v in order:
        if not in_tree[v]:
            p = parent[v]
            entry[v] = entry[p]
            hops[v] = hops[p] + 1

    # Double sweep for a diameter (end_a, end_b) of T. Depth in the whole
    # tree equals depth inside T (ancestors of T-vertices are in T), so the
    # deepest T-vertex is the farthest from root and is one diameter end.
    # In a tree, the farthest vertex from any y is at distance
    # max(dist(y, end_a), dist(y, end_b)).
    end_a = max(tree_vertices, key=depth.__getitem__)
    dist_a, _, _ = _bfs(adj, end_a, n, in_tree)
    end_b = max(tree_vertices, key=dist_a.__getitem__)
    dist_b, _, _ = _bfs(adj, end_b, n, in_tree)

    answers = []
    for x in range(1, n + 1):
        y = entry[x]
        answers.append(hops[x] + 2 * tree_edges - max(dist_a[y], dist_b[y]))
    print('\n'.join(map(str, answers)))

if __name__ == '__main__':
    # Run the solver only when executed as a script, not when imported.
    main()