import sys


def _nodes_within_two(adj, center):
    """Return the set of nodes reachable from *center* in at most two edges."""
    reach = {center}
    for near in adj[center]:
        reach.add(near)
        # Neighbours of a neighbour are the distance-2 candidates; this also
        # re-adds *center* itself, which the set silently deduplicates.
        reach.update(adj[near])
    return reach


def _gather(adj, fairies, center):
    """Move every fairy within two edges of *center* onto it.

    Mutates *fairies* in place and returns the resulting count at *center*.
    """
    total = 0
    for node in _nodes_within_two(adj, center):
        total += fairies[node]
        fairies[node] = 0
    # *center* was zeroed in the loop above along with the others, so it is
    # assigned the accumulated total only after the whole sweep.
    fairies[center] = total
    return total


def main():
    """Read a graph and gather queries from stdin; print one count per query.

    Input is consumed as whitespace-separated integers, in this order:

    * ``N`` -- the number of nodes;
    * ``N - 1`` pairs ``u v`` -- undirected edges.  Labels are used directly
      as list indices, so they must lie in ``0 .. N - 1``;
    * ``N`` values -- the initial fairy count at each node;
    * ``Q`` -- the number of queries;
    * ``Q`` node labels, one per query.

    For each query node ``x``, all fairies on nodes within two edges of ``x``
    are moved to ``x`` and the resulting count at ``x`` is printed on its own
    line.  Queries are applied in order and each one sees the effects of the
    previous ones.  Nothing is written when the input is empty or ``Q`` is 0.

    Each query walks the adjacency lists of ``x`` and of every neighbour of
    ``x``, so its cost grows with the combined degree of those nodes.
    """
    # Token-based parsing tolerates blank lines and arbitrary line wrapping,
    # which strict line-by-line reading does not.
    tokens = sys.stdin.read().split()
    if not tokens:
        return
    values = map(int, tokens)

    node_count = next(values)
    adj = [[] for _ in range(node_count)]
    for _ in range(node_count - 1):
        u = next(values)
        v = next(values)
        adj[u].append(v)
        adj[v].append(u)

    fairies = [next(values) for _ in range(node_count)]

    query_count = next(values)
    answers = []
    for _ in range(query_count):
        answers.append(str(_gather(adj, fairies, next(values))))

    # Skip the write entirely for zero queries so no stray blank line appears.
    if answers:
        sys.stdout.write("\n".join(answers) + "\n")


if __name__ == '__main__':
    main()