結果
問題 | No.1221 木 *= 3 |
ユーザー |
|
提出日時 | 2021-04-10 17:11:43 |
言語 | PyPy3 (7.3.15) |
結果 | AC |
実行時間 | 477 ms / 2,000 ms |
コード長 | 11,168 bytes |
コンパイル時間 | 404 ms |
コンパイル使用メモリ | 82,196 KB |
実行使用メモリ | 132,220 KB |
最終ジャッジ日時 | 2024-06-26 06:13:01 |
合計ジャッジ時間 | 8,516 ms |
ジャッジサーバーID (参考情報) | judge5 / judge1 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | AC * 4 |
other | AC * 18 |
ソースコード
class Tree:
    """Adjacency-list tree with rooting, rerooting DP and heavy-light decomposition.

    Vertices passed to and returned from the public methods are offset by
    ``decrement`` (``decrement=1`` means callers use 1-indexed vertices);
    internally every list is 0-indexed. Values handed to user callbacks
    (``op`` in ``rerooting``) and the ``dp`` lists are 0-indexed.
    """

    def __init__(self, n, decrement=1):
        self.n = n
        self.edges = [[] for _ in range(n)]        # neighbours of each vertex
        self._edge_label = [[] for _ in range(n)]  # (neighbour, edge label) pairs
        self.root = None
        self.size = [1] * n  # number of nodes in subtree (filled by set_root)
        self.decrement = decrement

    def add_edge(self, u, v, i):
        """Add the undirected edge u-v (caller's indexing) carrying label ``i``."""
        u, v = u - self.decrement, v - self.decrement
        self.edges[u].append(v)
        self.edges[v].append(u)
        self._edge_label[u].append((v, i))
        self._edge_label[v].append((u, i))

    def add_edges(self, edges):
        """Add every (u, v) pair of ``edges``, labelling each with its position."""
        for i, (u, v) in enumerate(edges):
            self.add_edge(u, v, i)

    def set_root(self, root):
        """Root the tree at ``root`` and fill par, depth, order, edge_label and size.

        ``order`` lists every vertex after its parent (root first), so
        ``order[::-1]`` is a valid leaf-to-root order. ``edge_label[v]`` is the
        label of the edge v-par[v] (-1 for the root). ``size[v]`` is the number
        of vertices in v's subtree.
        """
        root -= self.decrement
        self.root = root
        self.depth = [-1] * self.n
        self.par = [-1] * self.n
        self.edge_label = [-1] * self.n
        # Reset sizes so calling set_root again does not add onto the sizes
        # computed by a previous call.
        self.size = [1] * self.n
        self.depth[root] = 0
        self.order = [root]
        stack = [root]
        while stack:
            p = stack.pop()
            for q, i in self._edge_label[p]:
                if self.depth[q] != -1:  # already discovered (p's parent)
                    continue
                self.par[q] = p
                self.depth[q] = self.depth[p] + 1
                self.edge_label[q] = i
                self.order.append(q)
                stack.append(q)
        for p in reversed(self.order):
            for q in self.edges[p]:
                if self.par[p] == q:
                    continue
                self.size[p] += self.size[q]

    def diameter(self, path=False):
        """Return the diameter of the tree, measured in edges.

        Requires set_root: the deepest vertex from the root is used as the
        first endpoint ``u`` and the vertex farthest from it as ``v``.
        With ``path=True`` returns ``(d, v, u, nodes)``, where ``nodes`` lists
        the vertices from v to u. All vertices are in the caller's indexing.
        """
        # assert self.root is not None
        u = self.depth.index(max(self.depth))
        dist = [-1] * self.n
        dist[u] = 0
        prev = [-1] * self.n
        stack = [u]
        while stack:
            p = stack.pop()
            for q in self.edges[p]:
                if dist[q] != -1:
                    continue
                dist[q] = dist[p] + 1
                prev[q] = p
                stack.append(q)
        d = max(dist)
        if not path:
            return d
        v = w = dist.index(d)
        # Offset by decrement like every other returned vertex (was a fixed +1).
        nodes = [v + self.decrement]
        while w != u:
            w = prev[w]
            nodes.append(w + self.decrement)
        return d, v + self.decrement, u + self.decrement, nodes

    def rerooting(self, op, merge, id):
        """All-roots tree DP. Requires set_root.

        ``op(x, p, q)`` turns the value ``x`` of the component hanging off
        neighbour ``q`` into a contribution for vertex ``p`` (0-indexed).
        ``merge`` combines contributions, with identity ``id``. Suffix and
        prefix results are merged in both orders below, so ``merge`` should
        be associative and commutative.

        Returns a list where entry v is the merge, over all neighbours w of v,
        of ``op(f(w, v), v, w)``. Here ``f(w, v)`` is the same quantity
        computed at w with neighbour v removed.
        (``id`` shadows the builtin; the name is kept for keyword callers.)
        """
        # assert self.root is not None
        dp1 = [id] * self.n  # subtree result, later the all-roots result
        dp2 = [id] * self.n  # result of the parent side, excluding own subtree
        for p in reversed(self.order):
            # Prefix merges over p's children (excluding the current child).
            t = id
            for q in self.edges[p]:
                if self.par[p] == q:
                    continue
                dp2[q] = t
                t = merge(t, op(dp1[q], p, q))
            # Add suffix merges, so dp2[q] covers all siblings of q.
            t = id
            for q in reversed(self.edges[p]):
                if self.par[p] == q:
                    continue
                dp2[q] = merge(t, dp2[q])
                t = merge(t, op(dp1[q], p, q))
            dp1[p] = t
        for q in self.order[1:]:
            pq = self.par[q]
            # Value of pq without q's subtree, pushed down across edge pq-q.
            dp2[q] = op(merge(dp2[q], dp2[pq]), q, pq)
            dp1[q] = merge(dp1[q], dp2[q])
        return dp1

    def heavy_light_decomposition(self):
        """Heavy-light decomposition; return the flat vertex array ``hld``.

        ``hld[i]`` is the vertex (caller's indexing) at position i. ``vid[v]``
        is the position of 0-indexed vertex v. ``head[v]`` is the top vertex
        of v's heavy path. Each heavy path and each subtree occupies
        contiguous positions. Requires set_root.
        """
        # assert self.root is not None
        self.vid = [-1] * self.n
        self.hld = [-1] * self.n
        self.head = [-1] * self.n
        self.head[self.root] = self.root
        self.heavy_node = [-1] * self.n
        stack = [self.root]
        for i in range(self.n):
            # For a tree the DFS pops exactly n vertices.
            p = stack.pop()
            self.vid[p] = i
            self.hld[i] = p + self.decrement
            # The heavy child is the child with the largest subtree.
            maxs = 0
            for q in self.edges[p]:
                if self.par[p] == q:
                    continue
                if maxs < self.size[q]:
                    maxs = self.size[q]
                    self.heavy_node[p] = q
            # Light children start their own heavy paths.
            for q in self.edges[p]:
                if self.par[p] == q or self.heavy_node[p] == q:
                    continue
                self.head[q] = q
                stack.append(q)
            # Push the heavy child last so it is popped next, which keeps
            # each heavy path contiguous in hld.
            if self.heavy_node[p] != -1:
                self.head[self.heavy_node[p]] = self.head[p]
                stack.append(self.heavy_node[p])
        return self.hld

    def lca(self, u, v):
        """Return the lowest common ancestor of u and v (caller's indexing).

        Requires heavy_light_decomposition.
        """
        # assert self.head is not None
        u, v = u - self.decrement, v - self.decrement
        while True:
            if self.vid[u] > self.vid[v]:
                u, v = v, u
            if self.head[u] != self.head[v]:
                v = self.par[self.head[v]]
            else:
                return u + self.decrement

    def path(self, u, v):
        """Yield the vertices of the u-v path in order, u and v included."""
        p = self.lca(u, v)
        u, v, p = u - self.decrement, v - self.decrement, p - self.decrement
        # Climb from u to the LCA.
        while u != p:
            yield u + self.decrement
            u = self.par[u]
        yield p + self.decrement
        # Collect the v side bottom-up, then emit it top-down.
        tail = []
        while v != p:
            tail.append(v)
            v = self.par[v]
        for w in reversed(tail):
            yield w + self.decrement

    def distance(self, u, v):
        """Return the number of edges on the u-v path."""
        # assert self.head is not None
        p = self.lca(u, v)
        u, v, p = u - self.decrement, v - self.decrement, p - self.decrement
        return self.depth[u] + self.depth[v] - 2 * self.depth[p]

    def find(self, u, v, x):
        """Return True when vertex x lies on the u-v path."""
        return self.distance(u, x) + self.distance(x, v) == self.distance(u, v)

    def path_to_list(self, u, v, edge_query=False):
        """Yield half-open ranges [l, r) of ``hld`` positions covering the u-v path.

        With ``edge_query`` true, the range containing the LCA starts one
        position later. The LCA's own position is then excluded, which suits
        storing the edge par-chi at position vid[chi]; the root is therefore
        never included either.
        """
        # assert self.head is not None
        u, v = u - self.decrement, v - self.decrement
        while True:
            if self.vid[u] > self.vid[v]:
                u, v = v, u
            if self.head[u] != self.head[v]:
                yield self.vid[self.head[v]], self.vid[v] + 1
                v = self.par[self.head[v]]
            else:
                yield self.vid[u] + edge_query, self.vid[v] + 1
                return

    def ver_to_idx(self, u):
        """Return the index on self.hld corresponding to vertex u."""
        return self.vid[u - self.decrement]

    def idx_to_ver(self, i):
        """Return the vertex (caller's indexing) at index i of self.hld."""
        return self.hld[i]

    def idx_to_edge(self, i):
        """Return the label of the parent edge of the vertex at index i of self.hld."""
        return self.edge_label[self.hld[i] - self.decrement]

    def subtree_query(self, u):
        """Return the half-open hld range [l, r) covering u's subtree."""
        u -= self.decrement
        return self.vid[u], self.vid[u] + self.size[u]

    def top_down(self, dp, merge=None):
        """Propagate values from parent to child in root-to-leaf order, in place.

        Each non-root vertex v gets ``dp[v] = merge(dp[v], dp[par[v]])``.
        With the default XOR merge, dp[v] becomes the XOR of the original
        values on the root-to-v path. ``dp`` is indexed by 0-indexed vertex.
        """
        if merge is None:
            def merge(dp_chi, dp_par):
                return dp_chi ^ dp_par
        for chi in self.order[1:]:
            dp[chi] = merge(dp[chi], dp[self.par[chi]])
        return dp

    def top_down_edge_query(self, dp, merge=None):
        """Like top_down, but ``dp`` is indexed by edge label.

        Each edge's value is merged with the value of its parent edge (the
        edge one step closer to the root). The default merge is XOR.
        """
        if merge is None:
            def merge(dp_chi, dp_par):
                return dp_chi ^ dp_par
        # set_root appends all of the root's children right after the root.
        # Their edges have no parent edge, so they are skipped.
        for p in self.order[1 + len(self.edges[self.root]):]:
            chi, par = self.edge_label[p], self.edge_label[self.par[p]]
            dp[chi] = merge(dp[chi], dp[par])
        return dp

    def bottom_up(self, dp, B=None):
        """Problem-specific two-state leaf-to-root DP, in place. Requires set_root.

        Each vertex is in state 0 or state 1. An edge whose endpoints are both
        in state 1 earns ``B[par] + B[chi]``. After the call, ``dp[s][v]`` is
        the best total over v's subtree with v in state s. The values the
        caller pre-fills in ``dp`` act as per-vertex bonuses; the main script
        puts A[v] into dp[0][v].

        ``B`` is indexed by 0-indexed vertex. When omitted, the module-level
        ``B`` is used, as the original version did implicitly.
        """
        if B is None:
            B = globals()["B"]
        for par in reversed(self.order):
            for chi in self.edges[par]:
                if self.par[par] == chi:
                    continue
                dp[1][par] += max(dp[0][chi], dp[1][chi] + B[par] + B[chi])
                dp[0][par] += max(dp[0][chi], dp[1][chi])
        return dp

    def draw(self):
        """Plot the tree (needs networkx and matplotlib)."""
        import matplotlib.pyplot as plt
        import networkx as nx
        G = nx.Graph()
        for x in range(self.n):
            for y in self.edges[x]:
                G.add_edge(x + self.decrement, y + self.decrement)
        pos = nx.spring_layout(G)
        nx.draw_networkx(G, pos)
        plt.axis("off")
        plt.show()


#########################################################################################################


def example_tree(N=10, show=True, decrement=1):
    """Generate a random tree on N vertices and redirect ``input`` to read it.

    Returns the edge list ``(u, v)`` offset by ``decrement``. When ``show``
    is true, the generated input is also printed.
    """
    global input
    # decrement=True: create 1-indexed tree
    import random

    def find(x):
        tmp = []
        while parents[x] >= 0:
            tmp.append(x)
            x = parents[x]
        for y in tmp:
            parents[y] = x
        return x

    def union(x, y):
        x, y = find(x), find(y)
        if x == y:
            return
        if parents[x] > parents[y]:
            x, y = y, x
        parents[x] += parents[y]
        parents[y] = x

    def same(x, y):
        return find(x) == find(y)

    parents = [-1] * N
    edges = []
    # Join vertex i to a random vertex of another component; after N-1
    # joins everything is one component, i.e. a tree.
    for i in range(N - 1):
        while True:
            j = random.randint(0, N - 1)
            if same(i, j):
                continue
            union(i, j)
            edges.append((i + decrement, j + decrement))
            break
    input_data = [str(N)]
    for i, j in edges:
        input_data.append(" ".join(map(str, (i, j))))
    input_data = iter(input_data)
    input = lambda: next(input_data)
    if show:
        print("#######################")
        print(N)
        for p in edges:
            print(*p)
        print("#######################")
    return edges


def example_special_tree(N, type, show=True, decrement=1):
    """Build a path, star or complete binary tree on N vertices.

    Also redirects ``input`` to read the generated input. Returns the edge
    list ``(u, v)`` offset by ``decrement``. An unknown ``type`` yields no
    edges.
    """
    global input
    edges = []
    if type == "path":
        for i in range(N - 1):
            edges.append((i + decrement, i + 1 + decrement))
    if type == "star":
        # Every other vertex hangs off vertex 0 (the centre).
        for i in range(N - 1):
            edges.append((decrement, i + 1 + decrement))
    if type == "binary":
        # Heap layout: the parent of 0-indexed vertex c is (c - 1) // 2.
        for c in range(1, N):
            edges.append(((c - 1) // 2 + decrement, c + decrement))
    input_data = [str(N)]
    for i, j in edges:
        input_data.append(" ".join(map(str, (i, j))))
    input_data = iter(input_data)
    input = lambda: next(input_data)
    if show:
        print("#######################")
        print(N)
        for p in edges:
            print(*p)
        print("#######################")
    return edges


def draw(edges, decrement=1):
    """Plot a tree given as a list of (u, v) pairs (needs networkx and matplotlib).

    Vertices are labelled exactly as they appear in ``edges``, e.g. the
    output of example_tree or example_special_tree, which is already offset.
    ``decrement`` is accepted for backward compatibility and is unused.
    """
    import matplotlib.pyplot as plt
    import networkx as nx
    G = nx.Graph()
    for x, y in edges:
        G.add_edge(x, y)
    pos = nx.spring_layout(G)
    nx.draw_networkx(G, pos)
    plt.axis("off")
    plt.show()


def example():
    """Redirect ``input`` to a hard-coded sample (debug helper; not called here)."""
    global input
    # NOTE(review): line breaks reconstructed from a collapsed copy. The first
    # line "5 2" would not parse as the single integer N read by the main
    # block, so this sample presumably belongs to a different problem.
    example = iter("""
5 2
1 5
2 3
3 5
4 3
1 2 3
1 4 2
""".strip().split("\n"))
    input = lambda: next(example)


#########################################################################################################


def solve(N, A, B, edges):
    """Return the best total of the two-state DP for yukicoder No.1221.

    ``A`` and ``B`` have one entry per vertex (0-indexed lists). ``edges``
    are 1-indexed (x, y) pairs. A vertex in state 0 earns A[v]; an edge with
    both endpoints in state 1 earns B[x] + B[y] (see Tree.bottom_up).
    """
    T = Tree(N, decrement=1)
    for i, (x, y) in enumerate(edges):
        T.add_edge(x, y, i)
    T.set_root(1)
    dp = [[0] * N for _ in range(2)]
    for i, a in enumerate(A):
        dp[0][i] += a
    T.bottom_up(dp, B)
    return max(dp[0][0], dp[1][0])


if __name__ == "__main__":
    import sys
    input = sys.stdin.readline
    # example_tree(N=10, decrement=True)
    # example_special_tree(N=10, type="star", decrement=True)
    N = int(input())
    A = list(map(int, input().split()))
    B = list(map(int, input().split()))
    edges = [tuple(map(int, input().split())) for _ in range(N - 1)]
    print(solve(N, A, B, edges))