# NOTE(review): this file is damaged and cannot run as-is.
#   * All newlines and indentation were lost, so everything sits on one
#     physical line (a SyntaxError in Python).
#   * Text appears to have been cut out wherever a "<" immediately followed
#     by a letter met a later ">" (looks like HTML-tag stripping -- confirm).
#     Visible scars: "self.size=1<>i)", "assert 0<=p and p>=1 r>>=1",
#     "while(l1 & (r%2))", "while(r 0: C[P[i]]...".
#   * The lost spans include:
#       - the rest of __init__;
#       - the method whose tail is "return self.op(sml,smr)";
#       - most of the leftward search that follows max_right;
#       - the start of the top-level script, where N, P, A, B and C are
#         presumably defined.
#   * Restore from version control rather than hand-patching the visible
#     text.
#
# Part 1 -- class segtree(V, OP, E):
#   * Array-backed segment tree over the list V. It combines values with
#     the binary operator OP and the identity E, stored as self.op and
#     self.e.
#   * self.log is the bit length of n-1. all_prod() returns the root node
#     self.d[1].
#   * max_right(l, f) asserts 0 <= l <= n and f(self.e), returns n when
#     l == n, and otherwise searches starting from node l + self.size.
#   * The script below also calls seg.prod(l, r) and seg.set(p, x). Their
#     definitions are part of the lost text.
#   * Presumably a port of the AtCoder Library segtree -- TODO confirm.
#   * The class-level attributes (n, size, log, d, op, e) are placeholders.
#     __init__ reassigns n, op, e, log and size on the instance; whether it
#     also rebuilds d is in the lost text.
#
# Part 2 -- top-level script (tree DP):
#   * C[P[i]].append(i) builds child lists from the parent array P. The
#     guarding condition ending in "0:" was partly lost.
#   * An explicit stack builds an Euler tour `et` starting at vertex 0.
#       - A pushed ~x means "return to x".
#       - m[v] / M[v] record the first / last tour index of v, so the range
#         [m[c], M[c]] covers c's subtree.
#       - M[v] = max(m[v], i) always equals i, because m[v] <= i right
#         after m[v] = min(m[v], i).
#   * seg is a max-combining segment tree over the tour positions
#     (identity 0).
#   * Pass 1 visits vertices in descending A[v]:
#       - dp[v] = B[v] + sum over children c of seg.prod(m[c], M[c] + 1),
#         i.e. the max value stored so far in c's range (0 if none).
#       - dp[v] is then stored at position m[v].
#   * Pass 2 computes BFS depths from vertex 0.
#       - depth[0] is never assigned and stays -1; only the relative order
#         matters.
#       - Vertices are visited deepest-first, and dp[v] is raised to the sum
#         of its children's range maxima when that is larger.
#   * The script prints max(dp). `ans = 0` is assigned but never used.
class segtree(): n=1 size=1 log=2 d=[0] op=None e=10**15 def __init__(self,V,OP,E): self.n=len(V) self.op=OP self.e=E self.log=(self.n-1).bit_length() self.size=1<>i) def get(self,p): assert 0<=p and p>=1 r>>=1 return self.op(sml,smr) def all_prod(self): return self.d[1] def max_right(self,l,f): assert 0<=l and l<=self.n assert f(self.e) if l==self.n: return self.n l+=self.size sm=self.e while(1): while(l%2==0): l>>=1 if not(f(self.op(sm,self.d[l]))): while(l1 & (r%2)): r>>=1 if not(f(self.op(self.d[r],sm))): while(r 0: C[P[i]].append(i) m = [10 ** 9] * N M = [-1] * N et = [] from collections import deque st = deque([0]) i = 0 while len(st): v = st.pop() if v < 0: v = ~v if et[-1] == v: continue et.append(v) m[v] = min(m[v], i) M[v] = max(m[v], i) i += 1 else: et.append(v) m[v] = min(m[v], i) M[v] = max(m[v], i) i += 1 for c in C[v]: st.append(~v) st.append(~c) st.append(c) V = [i for i in range(N)] V.sort(key = lambda x: -A[x]) seg = segtree([0] * len(et), lambda x, y: max(x, y), 0) dp = [0] * N for v in V: for c in C[v]: dp[v] += seg.prod(m[c], M[c] + 1) dp[v] += B[v] seg.set(m[v], dp[v]) ans = 0 depth = [-1] * N qu = deque([0]) while len(qu): v = qu.popleft() for c in C[v]: qu.append(c) depth[c] = depth[v] + 1 V.sort(key = lambda x: -depth[x]) for v in V: tmp = 0 for c in C[v]: tmp += seg.prod(m[c], M[c] + 1) dp[v] = max(dp[v], tmp) seg.set(m[v], dp[v]) print(max(dp))