結果
問題 | No.1221 木 *= 3 |
ユーザー | None |
提出日時 | 2021-04-10 17:11:43 |
言語 | PyPy3 (7.3.15) |
結果 | AC |
実行時間 | 477 ms / 2,000 ms |
コード長 | 11,168 bytes |
コンパイル時間 | 404 ms |
コンパイル使用メモリ | 82,196 KB |
実行使用メモリ | 132,220 KB |
最終ジャッジ日時 | 2024-06-26 06:13:01 |
合計ジャッジ時間 | 8,516 ms |
ジャッジサーバーID (参考情報) | judge5 / judge1 |
(要ログイン)
テストケース
テストケース | 結果 | 実行時間 | 実行使用メモリ |
---|---|---|---|
testcase_00 | AC | 45 ms | 54,400 KB |
testcase_01 | AC | 45 ms | 54,784 KB |
testcase_02 | AC | 43 ms | 54,272 KB |
testcase_03 | AC | 42 ms | 54,400 KB |
testcase_04 | AC | 43 ms | 54,400 KB |
testcase_05 | AC | 43 ms | 54,272 KB |
testcase_06 | AC | 42 ms | 54,528 KB |
testcase_07 | AC | 334 ms | 119,432 KB |
testcase_08 | AC | 322 ms | 119,356 KB |
testcase_09 | AC | 354 ms | 120,100 KB |
testcase_10 | AC | 343 ms | 119,576 KB |
testcase_11 | AC | 336 ms | 119,692 KB |
testcase_12 | AC | 313 ms | 130,260 KB |
testcase_13 | AC | 314 ms | 132,220 KB |
testcase_14 | AC | 321 ms | 130,592 KB |
testcase_15 | AC | 318 ms | 130,760 KB |
testcase_16 | AC | 318 ms | 130,764 KB |
testcase_17 | AC | 449 ms | 120,776 KB |
testcase_18 | AC | 467 ms | 120,568 KB |
testcase_19 | AC | 421 ms | 122,784 KB |
testcase_20 | AC | 465 ms | 120,556 KB |
testcase_21 | AC | 477 ms | 120,400 KB |
ソースコード
import sys

input = sys.stdin.readline


class Tree:
    """Static tree on ``n`` vertices with rooting, HLD, rerooting and DP helpers.

    Vertex ids passed to / returned from the public methods are offset by
    ``decrement`` (``decrement=1`` means the caller uses 1-indexed vertices);
    every internal array is 0-indexed.
    """

    def __init__(self, n, decrement=1):
        self.n = n
        self.edges = [[] for _ in range(n)]  # adjacency lists (0-indexed)
        self._edge_label = [[] for _ in range(n)]  # (neighbour, edge label) pairs
        self.root = None
        self.size = [1] * n  # number of nodes in subtree (filled by set_root)
        self.decrement = decrement

    def add_edge(self, u, v, i):
        """Add the undirected edge ``u``-``v`` (public ids) with label ``i``."""
        u, v = u - self.decrement, v - self.decrement
        self.edges[u].append(v)
        self.edges[v].append(u)
        self._edge_label[u].append((v, i))
        self._edge_label[v].append((u, i))

    def add_edges(self, edges):
        """Add every ``(u, v)`` pair of ``edges``; the k-th pair gets label k."""
        for i, (u, v) in enumerate(edges):
            self.add_edge(u, v, i)

    def set_root(self, root):
        """Root the tree at ``root`` (public id).

        Fills ``depth``, ``par`` (-1 for the root), ``edge_label`` (label of
        the edge to the parent, -1 for the root), ``order`` (every vertex
        appears after its parent) and ``size`` (subtree sizes).  May be
        called again to re-root the tree.
        """
        root -= self.decrement
        self.root = root
        self.depth = [-1] * self.n
        self.par = [-1] * self.n
        self.edge_label = [-1] * self.n
        # Reset subtree sizes: accumulating on top of a previous call's
        # sizes would inflate every non-leaf entry.
        self.size = [1] * self.n
        self.depth[root] = 0
        self.order = [root]
        stack = [root]
        while stack:
            p = stack.pop()
            for q, i in self._edge_label[p]:
                if self.depth[q] != -1:
                    continue
                self.par[q] = p
                self.depth[q] = self.depth[p] + 1
                self.edge_label[q] = i
                self.order.append(q)
                stack.append(q)
        # A vertex is appended to ``order`` only after its parent, so the
        # reversed order accumulates sizes children-first.
        for p in reversed(self.order):
            for q in self.edges[p]:
                if self.par[p] == q:
                    continue
                self.size[p] += self.size[q]

    def diameter(self, path=False):
        """Return the tree diameter (number of edges) via a double sweep.

        Requires ``set_root``: the deepest vertex w.r.t. the root is used as
        the first endpoint ``u``.  With ``path=True`` returns
        ``(d, v, u, vertices)`` where ``vertices`` lists the public ids on
        the diameter path from ``v`` to ``u``.
        """
        # assert self.root is not None
        u = self.depth.index(max(self.depth))
        dist = [-1] * self.n
        dist[u] = 0
        prev = [-1] * self.n
        stack = [u]
        while stack:
            p = stack.pop()
            for q in self.edges[p]:
                if dist[q] != -1:
                    continue
                dist[q] = dist[p] + 1
                prev[q] = p
                stack.append(q)
        d = max(dist)
        if not path:
            return d
        v = w = dist.index(d)
        # Offset by ``decrement`` like every other id in the returned path.
        vertices = [v + self.decrement]
        while w != u:
            w = prev[w]
            vertices.append(w + self.decrement)
        return d, v + self.decrement, u + self.decrement, vertices

    def rerooting(self, op, merge, id):
        """Rerooting DP over all vertices.  Requires ``set_root``.

        ``op(value, p, q)`` transforms the value of neighbour ``q`` before it
        is combined at ``p``; ``merge`` combines transformed values and
        ``id`` is its identity.  Returns a list indexed by internal
        (0-indexed) vertex.
        """
        # assert self.root is not None
        dp1 = [id] * self.n
        dp2 = [id] * self.n
        # Bottom-up pass: dp1[p] merges its children; dp2[q] collects the
        # merge of q's siblings (prefix from the left, then suffix from the right).
        for p in self.order[::-1]:
            t = id
            for q in self.edges[p]:
                if self.par[p] == q:
                    continue
                dp2[q] = t
                t = merge(t, op(dp1[q], p, q))
            t = id
            for q in self.edges[p][::-1]:
                if self.par[p] == q:
                    continue
                dp2[q] = merge(t, dp2[q])
                t = merge(t, op(dp1[q], p, q))
            dp1[p] = t
        # Top-down pass: push the parent-side value into every child.
        for q in self.order[1:]:
            pq = self.par[q]
            dp2[q] = op(merge(dp2[q], dp2[pq]), q, pq)
            dp1[q] = merge(dp1[q], dp2[q])
        return dp1

    def heavy_light_decomposition(self):
        """Compute a heavy-light decomposition.  Requires ``set_root``.

        Vertices are laid out in a stack-based DFS preorder that always
        visits the heavy child next, so every heavy chain and every subtree
        occupies a contiguous index range.  Fills ``vid`` (index of each
        vertex), ``hld`` (public vertex id at each index), ``head`` (top
        vertex of each vertex's chain) and ``heavy_node`` (heavy child or
        -1).  Returns ``self.hld``.
        """
        # assert self.root is not None
        self.vid = [-1] * self.n
        self.hld = [-1] * self.n
        self.head = [-1] * self.n
        self.head[self.root] = self.root
        self.heavy_node = [-1] * self.n
        stack = [self.root]
        # In a tree the DFS pops exactly n vertices.
        for i in range(self.n):
            p = stack.pop()
            self.vid[p] = i
            self.hld[i] = p + self.decrement
            # Pick the child with the largest subtree as the heavy child.
            maxs = 0
            for q in self.edges[p]:
                if self.par[p] == q:
                    continue
                if maxs < self.size[q]:
                    maxs = self.size[q]
                    self.heavy_node[p] = q
            # Light children start their own chains.
            for q in self.edges[p]:
                if self.par[p] == q or self.heavy_node[p] == q:
                    continue
                self.head[q] = q
                stack.append(q)
            # Push the heavy child last so it is popped next (same chain).
            if self.heavy_node[p] != -1:
                self.head[self.heavy_node[p]] = self.head[p]
                stack.append(self.heavy_node[p])
        return self.hld

    def lca(self, u, v):
        """Lowest common ancestor of public ids ``u`` and ``v``.

        Requires ``heavy_light_decomposition``.
        """
        # assert self.head is not None
        u, v = u - self.decrement, v - self.decrement
        while True:
            if self.vid[u] > self.vid[v]:
                u, v = v, u
            if self.head[u] != self.head[v]:
                v = self.par[self.head[v]]
            else:
                return u + self.decrement

    def path(self, u, v):
        """Yield the public ids on the unique ``u``-``v`` path, in order.

        This is a generator; wrap it in ``list`` for an array.  Requires
        ``heavy_light_decomposition`` (it calls ``lca``).
        """
        p = self.lca(u, v)
        u, v, p = u - self.decrement, v - self.decrement, p - self.decrement
        R = []
        while u != p:
            yield u + self.decrement
            u = self.par[u]
        yield p + self.decrement
        while v != p:
            R.append(v)
            v = self.par[v]
        for v in reversed(R):
            yield v + self.decrement

    def distance(self, u, v):
        """Number of edges between public ids ``u`` and ``v`` (needs HLD)."""
        # assert self.head is not None
        p = self.lca(u, v)
        u, v, p = u - self.decrement, v - self.decrement, p - self.decrement
        return self.depth[u] + self.depth[v] - 2 * self.depth[p]

    def find(self, u, v, x):
        """Return True iff vertex ``x`` lies on the ``u``-``v`` path (needs HLD)."""
        return self.distance(u, x) + self.distance(x, v) == self.distance(u, v)

    def path_to_list(self, u, v, edge_query=False):
        """Yield half-open index intervals on ``self.hld`` covering the u-v path.

        edge_query: map each edge (par, chi) to the point of ``chi``, so the
        LCA's own index is skipped (the root is therefore never updated).
        """
        # assert self.head is not None
        u, v = u - self.decrement, v - self.decrement
        while True:
            if self.vid[u] > self.vid[v]:
                u, v = v, u
            if self.head[u] != self.head[v]:
                yield self.vid[self.head[v]], self.vid[v] + 1
                v = self.par[self.head[v]]
            else:
                yield self.vid[u] + edge_query, self.vid[v] + 1
                return

    def ver_to_idx(self, u):
        """Return the index on ``self.hld`` of public vertex ``u``."""
        return self.vid[u - self.decrement]

    def idx_to_ver(self, i):
        """Return the public vertex id stored at index ``i`` of ``self.hld``."""
        return self.hld[i]

    def idx_to_edge(self, i):
        """Return the label of the parent edge of the vertex at hld index ``i``."""
        return self.edge_label[self.hld[i] - self.decrement]

    def subtree_query(self, u):
        """Half-open hld index interval covering the subtree of public ``u``."""
        u -= self.decrement
        return self.vid[u], self.vid[u] + self.size[u]

    def top_down(self, dp):
        """XOR every vertex value with its parent's (already updated) value.

        ``dp`` is indexed by internal vertex and is modified in place.
        """
        def merge(dp_chi, dp_par):
            return dp_chi ^ dp_par

        for chi in self.order[1:]:
            par = self.par[chi]
            dp[chi] = merge(dp[chi], dp[par])
        return dp

    def top_down_edge_query(self, dp):
        """Like ``top_down`` but ``dp`` is indexed by edge label.

        The root's children occupy ``order[1:1 + deg(root)]`` (they are
        appended first by ``set_root``), so their edges have no parent edge
        and are skipped.
        """
        def merge(dp_chi, dp_par):
            return dp_chi ^ dp_par

        for p in self.order[1 + len(self.edges[self.root]):]:
            chi, par = self.edge_label[p], self.edge_label[self.par[p]]
            dp[chi] = merge(dp[chi], dp[par])
        return dp

    def bottom_up(self, dp, weights=None):
        """Two-state tree DP folded from children into parents, in place.

        ``dp[0][v]`` and ``dp[1][v]`` are two states per internal vertex
        (callers seed them).  A state-0 parent adds the better child state;
        a state-1 parent additionally earns ``weights[par] + weights[chi]``
        when the child is in state 1.  ``weights`` defaults to the module
        global ``B`` for backward compatibility.  Returns ``dp``.
        """
        if weights is None:
            weights = B  # noqa: F821 -- legacy: read the script's global B
        for par in reversed(self.order):
            for chi in self.edges[par]:
                if self.par[par] == chi:
                    continue
                dp[1][par] += max(dp[0][chi],
                                  dp[1][chi] + weights[par] + weights[chi])
                dp[0][par] += max(dp[0][chi], dp[1][chi])
        return dp

    def draw(self):
        """Plot the tree with networkx/matplotlib (local debugging only)."""
        import matplotlib.pyplot as plt
        import networkx as nx

        G = nx.Graph()
        for x in range(self.n):
            for y in self.edges[x]:
                G.add_edge(x + self.decrement, y + self.decrement)
        pos = nx.spring_layout(G)
        nx.draw_networkx(G, pos)
        plt.axis("off")
        plt.show()


#########################################################################################################
# Local-testing helpers: each replaces the module-level ``input`` so that
# ``main()`` reads generated data instead of stdin.


def example_tree(N=10, show=True, decrement=1):
    """Generate a random tree on ``N`` vertices for local testing.

    Vertex i (for i < N-1) is joined to a random vertex in a different
    union-find component.  Replaces the global ``input`` with a reader over
    the generated "N" + edge lines, optionally prints them, and returns the
    edge list with ids offset by ``decrement`` (1/True: 1-indexed tree).
    """
    global input
    import random

    parents = [-1] * N

    def find(x):
        tmp = []
        while parents[x] >= 0:
            tmp.append(x)
            x = parents[x]
        for y in tmp:  # path compression
            parents[y] = x
        return x

    def union(x, y):
        x, y = find(x), find(y)
        if x == y:
            return
        if parents[x] > parents[y]:
            x, y = y, x
        parents[x] += parents[y]
        parents[y] = x

    def same(x, y):
        return find(x) == find(y)

    edges = []
    for i in range(N - 1):
        while True:
            j = random.randint(0, N - 1)
            if same(i, j):
                continue
            union(i, j)
            edges.append((i + decrement, j + decrement))
            break
    input_data = [str(N)]
    for i, j in edges:
        input_data.append(" ".join(map(str, (i, j))))
    input_data = iter(input_data)
    input = lambda: next(input_data)
    if show:
        print("#######################")
        print(N)
        for p in edges:
            print(*p)
        print("#######################")
    return edges


def example_special_tree(N, type, show=True, decrement=1):
    """Build a "path", "star" or "binary" (heap-shaped) tree on ``N`` vertices.

    Returns the edge list (ids offset by ``decrement``) and replaces the
    global ``input`` with a reader over the generated "N" + edge lines.
    An unrecognised ``type`` yields no edges.
    """
    global input
    edges = []
    if type == "path":
        for i in range(N - 1):
            edges.append((i + decrement, i + 1 + decrement))
    if type == "star":
        # The first vertex (id ``decrement``) is the centre.
        for i in range(N - 1):
            edges.append((decrement, i + 1 + decrement))
    if type == "binary":
        # Heap numbering: 1-based vertex k has children 2k and 2k + 1,
        # which exist while they are <= N.
        i = 1
        while True:
            if (i << 1) > N:
                break
            edges.append((i - 1 + decrement, (i << 1) - 1 + decrement))
            if (i << 1) + 1 > N:
                break
            edges.append((i - 1 + decrement, (i << 1) + decrement))
            i += 1
    input_data = [str(N)]
    for i, j in edges:
        input_data.append(" ".join(map(str, (i, j))))
    input_data = iter(input_data)
    input = lambda: next(input_data)
    if show:
        print("#######################")
        print(N)
        for p in edges:
            print(*p)
        print("#######################")
    return edges


def draw(edges, decrement=1):
    """Plot a tree given as a list of ``(u, v)`` edge pairs.

    Pairs are drawn with the ids they already carry (the example generators
    above return ids already offset by their ``decrement``); ``decrement``
    is accepted only for backward compatibility.
    """
    import matplotlib.pyplot as plt
    import networkx as nx

    G = nx.Graph()
    for u, v in edges:
        G.add_edge(u, v)
    pos = nx.spring_layout(G)
    nx.draw_networkx(G, pos)
    plt.axis("off")
    plt.show()


def example():
    """Replace the global ``input`` with a reader over a hard-coded sample."""
    global input
    # NOTE(review): the original line breaks of this sample were lost in
    # the source; the layout below is a best guess and the data looks
    # incomplete for this problem's format (N, A, B, then N-1 edges).
    example = iter(
        """
5
2 1 5 2 3
3 5 4 3 1
2 3
1 4
2
"""
        .strip().split("\n"))
    input = lambda: next(example)


#########################################################################################################


def solve(N, A, B, edges):
    """Answer one instance of yukicoder No.1221.

    ``A`` and ``B`` are per-vertex value lists (0-indexed) and ``edges`` the
    N-1 ``(x, y)`` pairs with 1-indexed vertices.  State 0 is seeded with
    ``A``, ``Tree.bottom_up`` runs with ``B`` as weights from root 1, and
    the better of the root's two states is returned.
    """
    tree = Tree(N, decrement=1)
    for i, (x, y) in enumerate(edges):
        tree.add_edge(x, y, i)
    tree.set_root(1)
    dp = [[0] * N for _ in range(2)]
    for i, a in enumerate(A):
        dp[0][i] += a
    tree.bottom_up(dp, B)
    return max(dp[0][0], dp[1][0])


def main():
    """Read one instance via ``input`` and print its answer.

    For local debugging, call example_tree(...), example_special_tree(...)
    or example() first; they redirect ``input`` to generated data.
    """
    N = int(input())
    A = list(map(int, input().split()))
    B = list(map(int, input().split()))
    edges = [tuple(map(int, input().split())) for _ in range(N - 1)]
    print(solve(N, A, B, edges))


if __name__ == "__main__":
    main()