結果
問題 | No.399 動的な領主 |
ユーザー | shotoyoo |
提出日時 | 2021-07-25 23:05:14 |
言語 | PyPy3 (7.3.15) |
結果 | AC |
実行時間 | 531 ms / 2,000 ms |
コード長 | 8,894 bytes |
コンパイル時間 | 164 ms |
コンパイル使用メモリ | 82,176 KB |
実行使用メモリ | 104,320 KB |
最終ジャッジ日時 | 2024-07-21 07:30:10 |
合計ジャッジ時間 | 7,389 ms |
ジャッジサーバーID (参考情報) | judge1 / judge4 |
(要ログイン)
テストケース
入力 | 結果 | 実行時間 | 実行使用メモリ |
---|---|---|---|
testcase_00 | AC | 46 ms | 54,016 KB |
testcase_01 | AC | 44 ms | 53,888 KB |
testcase_02 | AC | 48 ms | 54,656 KB |
testcase_03 | AC | 48 ms | 54,784 KB |
testcase_04 | AC | 105 ms | 76,544 KB |
testcase_05 | AC | 164 ms | 78,976 KB |
testcase_06 | AC | 508 ms | 98,304 KB |
testcase_07 | AC | 528 ms | 98,432 KB |
testcase_08 | AC | 530 ms | 99,456 KB |
testcase_09 | AC | 504 ms | 98,944 KB |
testcase_10 | AC | 112 ms | 77,184 KB |
testcase_11 | AC | 147 ms | 79,104 KB |
testcase_12 | AC | 469 ms | 100,480 KB |
testcase_13 | AC | 469 ms | 99,840 KB |
testcase_14 | AC | 217 ms | 104,064 KB |
testcase_15 | AC | 221 ms | 104,320 KB |
testcase_16 | AC | 250 ms | 99,712 KB |
testcase_17 | AC | 487 ms | 98,432 KB |
testcase_18 | AC | 531 ms | 98,560 KB |
ソースコード
import sys
from itertools import chain

input = lambda: sys.stdin.readline().rstrip()
sys.setrecursionlimit(2 * 10**5 + 10)
write = lambda x: sys.stdout.write(x + "\n")
debug = lambda x: sys.stderr.write(x + "\n")
writef = lambda x: print("{:.12f}".format(x))


class HLD:
    """Heavy-light decomposition of a forest given as adjacency lists.

    After construction:
      * parent[v]   -- parent of v (-1 for a component root)
      * depth[v]    -- number of edges from v to its component root
      * size[v]     -- number of vertices in v's subtree
      * preorder[v] -- position of v in a DFS order in which every heavy path
                       is contiguous and v's subtree is
                       [preorder[v], preorder[v] + size[v])
      * head[v]     -- topmost vertex of the heavy path containing v

    Note: the adjacency lists in ``g`` are reordered in place (heavy child
    moved to the front).
    """

    def __init__(self, g, root=0, sg=None):
        """Build the decomposition.

        g: adjacency lists (mutated in place).
        root: vertex used as the root of the first component processed;
            remaining components are rooted at their lowest-numbered vertex
            encountered by the scan.
        sg: optional lazy segment tree used by update/query helpers.
        """
        self.g = g
        self.n = len(g)
        self.parent = [-1] * self.n
        self.size = [1] * self.n
        self.head = [0] * self.n
        self.preorder = [0] * self.n
        self.k = 0
        self.depth = [0] * self.n
        # Always define the attribute (None when no tree is supplied).
        self.sg = sg
        for v in chain(range(root, self.n), range(0, root)):
            if self.parent[v] == -1:
                self.dfs_pre(v)
                # A component root heads its own heavy path. The default 0 in
                # `head` is only correct when the root happens to be vertex 0.
                self.head[v] = v
                self.dfs_hld(v)

    def dfs_pre(self, v):
        """Fill parent/depth/size for v's component and move each vertex's
        heavy child (largest subtree) to the front of its adjacency list."""
        g = self.g
        stack = [v]
        order = [v]
        while stack:
            v = stack.pop()
            for u in g[v]:
                if self.parent[v] == u:
                    continue
                self.parent[u] = v
                self.depth[u] = self.depth[v] + 1
                stack.append(u)
                order.append(u)
        # Reverse discovery order: every child is finished before its parent.
        while order:
            v = order.pop()
            child_v = g[v]
            # Move the parent off the front so index 0 starts as a child
            # (unless the parent is the only neighbour).
            if len(child_v) and child_v[0] == self.parent[v]:
                child_v[0], child_v[-1] = child_v[-1], child_v[0]
            for i, u in enumerate(child_v):
                if u == self.parent[v]:
                    continue
                self.size[v] += self.size[u]
                if self.size[u] > self.size[child_v[0]]:
                    child_v[i], child_v[0] = child_v[0], child_v[i]

    def dfs_hld(self, v):
        """Assign preorder indices and heavy-path heads for v's component.

        Children are pushed in reverse adjacency order, so the heavy child
        (index 0) is popped and numbered right after its parent. This keeps
        each heavy path contiguous in ``preorder``.
        """
        stack = [v]
        while stack:
            v = stack.pop()
            self.preorder[v] = self.k
            self.k += 1
            nbrs = self.g[v]
            if not nbrs:
                # Isolated vertex (e.g. n == 1): nothing to descend into.
                continue
            top = nbrs[0]
            for u in reversed(nbrs):
                if u == self.parent[v]:
                    continue
                if u == top:
                    self.head[u] = self.head[v]
                else:
                    self.head[u] = u
                stack.append(u)

    def for_each(self, u, v):
        """Yield inclusive preorder intervals (l, r) covering the vertices on
        the u-v path. u and v must lie in the same tree."""
        while True:
            if self.preorder[u] > self.preorder[v]:
                u, v = v, u
            l = max(self.preorder[self.head[v]], self.preorder[u])
            r = self.preorder[v]
            yield l, r  # [l, r]
            if self.head[u] != self.head[v]:
                v = self.parent[self.head[v]]
            else:
                return

    def for_each_edge(self, u, v):
        """Yield inclusive preorder intervals (l, r) covering the edges on the
        u-v path; each edge is represented by its child endpoint."""
        while True:
            if self.preorder[u] > self.preorder[v]:
                u, v = v, u
            if self.head[u] != self.head[v]:
                yield self.preorder[self.head[v]], self.preorder[v]
                v = self.parent[self.head[v]]
            else:
                if u != v:
                    yield self.preorder[u] + 1, self.preorder[v]
                break

    def subtree(self, v):
        """Return the half-open preorder interval [l, r) of v's subtree."""
        l = self.preorder[v]
        r = self.preorder[v] + self.size[v]
        return l, r

    def lca(self, u, v):
        """Return the lowest common ancestor of u and v (same tree only)."""
        while True:
            if self.preorder[u] > self.preorder[v]:
                u, v = v, u
            if self.head[u] == self.head[v]:
                return u
            v = self.parent[self.head[v]]

    # The four helpers below need `self.sg` and the module-level names `op`
    # and `ninf`, which this file does not define; define them before use.
    def update(self, u, v, val):
        for l, r in self.for_each(u, v):
            self.sg.update(l, r + 1, val)

    def query(self, u, v):
        res = ninf
        for l, r in self.for_each(u, v):
            res = op(res, self.sg.query(l, r + 1))
        return res

    def update_edge(self, u, v, val):
        for l, r in self.for_each_edge(u, v):
            self.sg.update(l, r + 1, val)

    def query_edge(self, u, v):
        res = ninf
        for l, r in self.for_each_edge(u, v):
            res = op(res, self.sg.query(l, r + 1))
        return res


class LSG:
    """Lazy-propagation segment tree over n positions.

    The algebra is not passed in: it reads module-level names `op` (combine
    two values), `ninf` (value of an empty range / initial accumulator),
    `mapping(f, x)` (apply update f to value x), `composition(f, g)` and `f0`
    (the "no pending update" tag). None of these are defined in this file;
    they must be defined before an LSG is constructed or queried.
    Ranges are half-open [left, right).
    """

    def __init__(self, n, a=None):
        self._n = n
        self._ninf = ninf
        x = 0
        while (1 << x) < self._n:
            x += 1
        self._log = x
        self._size = 1 << self._log
        self._d = [ninf] * (2 * self._size)
        self._lz = [f0] * self._size
        if a is not None:
            for i in range(self._n):
                self._d[self._size + i] = a[i]
            for i in range(self._size - 1, 0, -1):
                self._update(i)

    def check(self):
        """Return every leaf value (pushing pending updates); for debugging."""
        return [self.query_point(p) for p in range(self._n)]

    def update_point(self, p, x):
        """Overwrite position p with x."""
        p += self._size
        for i in range(self._log, 0, -1):
            self._push(p >> i)
        self._d[p] = x
        for i in range(1, self._log + 1):
            self._update(p >> i)

    def query_point(self, p):
        """Return the value at position p."""
        p += self._size
        for i in range(self._log, 0, -1):
            self._push(p >> i)
        return self._d[p]

    def query(self, left, right):
        """Fold `op` over [left, right); `ninf` for an empty range."""
        if left == right:
            return ninf
        left += self._size
        right += self._size
        for i in range(self._log, 0, -1):
            if ((left >> i) << i) != left:
                self._push(left >> i)
            if ((right >> i) << i) != right:
                self._push(right >> i)
        sml = ninf
        smr = ninf
        while left < right:
            if left & 1:
                sml = op(sml, self._d[left])
                left += 1
            if right & 1:
                right -= 1
                smr = op(self._d[right], smr)
            left >>= 1
            right >>= 1
        return op(sml, smr)

    def query_all(self):
        """Return the fold over the whole array."""
        return self._d[1]

    def update(self, left, right, f):
        """Apply f to [left, right); pass right=None to apply f at `left` only."""
        if right is None:
            p = left
            p += self._size
            for i in range(self._log, 0, -1):
                self._push(p >> i)
            self._d[p] = mapping(f, self._d[p])
            for i in range(1, self._log + 1):
                self._update(p >> i)
        else:
            if left == right:
                return
            left += self._size
            right += self._size
            for i in range(self._log, 0, -1):
                if ((left >> i) << i) != left:
                    self._push(left >> i)
                if ((right >> i) << i) != right:
                    self._push((right - 1) >> i)
            l2 = left
            r2 = right
            while left < right:
                if left & 1:
                    self._all_apply(left, f)
                    left += 1
                if right & 1:
                    right -= 1
                    self._all_apply(right, f)
                left >>= 1
                right >>= 1
            left = l2
            right = r2
            for i in range(1, self._log + 1):
                if ((left >> i) << i) != left:
                    self._update(left >> i)
                if ((right >> i) << i) != right:
                    self._update((right - 1) >> i)

    def _update(self, k):
        self._d[k] = op(self._d[2 * k], self._d[2 * k + 1])

    def _all_apply(self, k, f) -> None:
        self._d[k] = mapping(f, self._d[k])
        if k < self._size:
            self._lz[k] = composition(f, self._lz[k])

    def _push(self, k):
        self._all_apply(2 * k, self._lz[k])
        self._all_apply(2 * k + 1, self._lz[k])
        self._lz[k] = f0

    def loc(self, l, r):
        # NOTE(review): _lz has only _size entries, so this slice always
        # starts past the end and returns [] - confirm the intended array.
        return self._lz[self._size + l : self._size + r]


def solve(n, edges, queries):
    """Return the answer for the tree and query paths.

    The answer is the sum over all vertices of c*(c+1)//2 (= 1 + 2 + ... + c),
    where c is the number of query paths passing through that vertex.

    n: number of vertices, labelled 1..n.
    edges: n-1 pairs (a, b), 1-indexed.
    queries: pairs (u, v), 1-indexed; u == v is a one-vertex path.
    """
    g = [[] for _ in range(n)]
    for a, b in edges:
        g[a - 1].append(b - 1)
        g[b - 1].append(a - 1)
    hld = HLD(g)

    ends = [0] * n  # number of query endpoints at each vertex
    lcas = [0] * n  # number of query LCAs at each vertex
    for u, v in queries:
        u -= 1
        v -= 1
        ends[u] += 1
        ends[v] += 1
        lcas[hld.lca(u, v)] += 1

    # Children always get a larger preorder index than their parent, so a
    # scan by decreasing preorder finishes each subtree before its parent.
    by_pos = [0] * n
    for v, pos in enumerate(hld.preorder):
        by_pos[pos] = v
    sub_ends = ends[:]
    sub_lcas = lcas[:]
    total = 0
    for v in reversed(by_pos):
        # Every endpoint in v's subtree counts once; a path whose LCA is
        # strictly below v has both endpoints there and must net to 0, and a
        # path whose LCA is v must net to 1.
        c = sub_ends[v] - 2 * sub_lcas[v] + lcas[v]
        total += c * (c + 1) // 2
        p = hld.parent[v]
        if p >= 0:
            sub_ends[p] += sub_ends[v]
            sub_lcas[p] += sub_lcas[v]
    return total


def main():
    """Read the judge input from stdin and print the answer."""
    it = iter(sys.stdin.buffer.read().split())
    n = int(next(it))
    edges = [(int(next(it)), int(next(it))) for _ in range(n - 1)]
    q = int(next(it))
    queries = [(int(next(it)), int(next(it))) for _ in range(q)]
    print(solve(n, edges, queries))


if __name__ == "__main__":
    main()