結果
問題 | No.2892 Lime and Karin |
ユーザー |
|
提出日時 | 2024-09-13 21:52:30 |
言語 | PyPy3 (7.3.15) |
結果 | AC |
実行時間 | 2,566 ms / 8,000 ms |
コード長 | 4,503 bytes |
コンパイル時間 | 267 ms |
コンパイル使用メモリ | 81,792 KB |
実行使用メモリ | 171,384 KB |
最終ジャッジ日時 | 2024-09-13 21:54:06 |
合計ジャッジ時間 | 65,692 ms |
ジャッジサーバーID (参考情報) | judge3 / judge5 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | AC * 3 |
other | AC * 52 |
ソースコード
import sys
from itertools import permutations
from heapq import heappop, heappush
from collections import deque
import random
import bisect

input = lambda: sys.stdin.readline().rstrip()
mi = lambda: map(int, input().split())
li = lambda: list(mi())


def solve(N, edge, S):
    """Count vertex pairs (u, v), u <= v, whose tree path has more 1s than 0s.

    Each vertex is given weight +1 if S[v] == 1 and -1 if S[v] == 0, so a
    path qualifies exactly when its total weight is positive.  Pairs with
    u < v are counted by centroid decomposition; single-vertex paths
    (u == v) are added at the end as sum(S).

    Args:
        N: number of vertices, labelled 0..N-1.
        edge: undirected adjacency lists; assumed to describe a tree
            (decomposition starts only from vertex 0).
        S: per-vertex labels, assumed to be 0 or 1.

    Returns:
        The number of qualifying pairs.
    """
    centroid_done = [False] * N  # vertices already removed as centroids
    tmp_parent = [None] * N      # BFS-parent scratch; reset after every use
    tmp_sz = [None] * N          # subtree-size scratch; reset after every use

    def find_centroid(root):
        """Return a centroid of the component of non-removed vertices containing root."""
        deq = deque([root])
        topo = []
        while deq:
            v = deq.popleft()
            topo.append(v)
            for nv in edge[v]:
                if centroid_done[nv] or nv == tmp_parent[v]:
                    continue
                tmp_parent[nv] = v
                deq.append(nv)
        tmp_n = len(topo)
        centroid = -1
        # Reverse BFS order guarantees children are sized before their parent.
        for v in reversed(topo):
            tmp_sz[v] = 1
            centroid_flg = True
            for nv in edge[v]:
                if centroid_done[nv] or nv == tmp_parent[v]:
                    continue
                if tmp_sz[nv] * 2 > tmp_n:
                    centroid_flg = False
                tmp_sz[v] += tmp_sz[nv]
            # The part of the component outside v's subtree must also be <= half.
            if 2 * (tmp_n - tmp_sz[v]) > tmp_n:
                centroid_flg = False
            if centroid_flg:
                centroid = v
        for v in topo:
            tmp_sz[v] = None
            tmp_parent[v] = None
        return centroid

    def f(n, A, x):
        """Count index pairs i < j with A[i] + A[j] + x > 0.

        Requires -n <= A[i] <= n for every i (values are bucketed by A[i] + n).
        """
        tmp_cnt = [0] * (2 * n + 1)
        for a in A:
            tmp_cnt[a + n] += 1
        # Suffix sums: tmp_cnt[k] == number of j with A[j] + n >= k.
        for i in reversed(range(2 * n)):
            tmp_cnt[i] += tmp_cnt[i + 1]
        res = 0
        for a in A:
            # A partner b must satisfy b >= -a - x + 1.
            lower = max((-a - x + 1) + n, 0)
            if lower <= 2 * n:
                res += tmp_cnt[lower]
            # Remove the pairing of a with itself.
            if a + a + x > 0:
                res -= 1
        # Every unordered pair was counted once from each end.
        return res // 2

    tmp_dep = [None] * N                    # path weight from centroid to v, centroid excluded
    tmp_centroid_child_parent = [None] * N  # which centroid child's subtree v belongs to
    tmp_centroid_child_group = [[] for v in range(N)]  # members of each centroid child's subtree

    def calc_sub(centroid):
        """Count pairs u < v of the current component whose path goes through
        centroid and has positive total weight; also resets the scratch arrays."""
        deq = deque([centroid])
        tmp_dep[centroid] = 0
        topo = []
        while deq:
            v = deq.popleft()
            topo.append(v)
            for nv in edge[v]:
                if centroid_done[nv] or nv == tmp_parent[v]:
                    continue
                tmp_parent[nv] = v
                if v == centroid:
                    tmp_centroid_child_parent[nv] = nv
                else:
                    tmp_centroid_child_parent[nv] = tmp_centroid_child_parent[v]
                tmp_centroid_child_group[tmp_centroid_child_parent[nv]].append(nv)
                tmp_dep[nv] = tmp_dep[v] + 2 * S[nv] - 1
                deq.append(nv)
        tmp_n = len(topo)
        center_w = 2 * S[centroid] - 1
        # Non-centroid pairs: path weight = dep[u] + dep[v] + weight(centroid).
        A = [tmp_dep[v] for v in topo if v != centroid]
        res = f(tmp_n, A, center_w)
        # Pairs (centroid, v).
        for v in topo:
            if v != centroid and tmp_dep[v] + center_w > 0:
                res += 1
        # Pairs inside one child subtree were counted above, but their real
        # path does not pass through the centroid: subtract them.
        for centroid_child in edge[centroid]:
            if centroid_done[centroid_child]:
                continue
            group = tmp_centroid_child_group[centroid_child]
            res -= f(len(group), [tmp_dep[v] for v in group], center_w)
        for v in topo:
            tmp_parent[v] = None
            tmp_sz[v] = None
            tmp_dep[v] = None
            tmp_centroid_child_group[v] = []
            tmp_centroid_child_parent[v] = None
        return res

    def centroid_decomp():
        """Sum calc_sub over every centroid of the decomposition (iterative, no recursion)."""
        res = 0
        root_deq = deque([0])
        while root_deq:
            root = root_deq.popleft()
            centroid = find_centroid(root)
            res += calc_sub(centroid)
            centroid_done[centroid] = True
            # Each remaining neighbour seeds a new, smaller component.
            for centroid_child in edge[centroid]:
                if not centroid_done[centroid_child]:
                    root_deq.append(centroid_child)
        return res

    # Single-vertex paths qualify exactly when S[u] == 1.
    return centroid_decomp() + sum(S)


def main():
    """Read N, N-1 one-indexed edges and the label string from stdin; print the answer."""
    N = int(input())
    edge = [[] for v in range(N)]
    for _ in range(N - 1):
        u, v = mi()
        edge[u - 1].append(v - 1)
        edge[v - 1].append(u - 1)
    S = [int(c) for c in input()]
    print(solve(N, edge, S))


if __name__ == "__main__":
    main()