結果
| 問題 |
No.399 動的な領主
|
| コンテスト | |
| ユーザー |
tktk_snsn
|
| 提出日時 | 2021-11-03 13:13:40 |
| 言語 | PyPy3 (7.3.15) |
| 結果 |
AC
|
| 実行時間 | 418 ms / 2,000 ms |
| コード長 | 3,581 bytes |
| コンパイル時間 | 548 ms |
| コンパイル使用メモリ | 82,304 KB |
| 実行使用メモリ | 123,052 KB |
| 最終ジャッジ日時 | 2024-10-13 02:30:04 |
| 合計ジャッジ時間 | 7,275 ms |
|
ジャッジサーバーID (参考情報) |
judge3 / judge2 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| other | AC * 19 |
ソースコード
import sys
input = sys.stdin.buffer.readline
sys.setrecursionlimit(10 ** 7)
class SegTree(object):
def __init__(self, N, op_data, u_data):
self._n = N
self.log = (N-1).bit_length()
self.size = 1 << self.log
self.op = op_data
self.e = u_data
self.data = [u_data] * (2 * self.size)
# self.len = [1] * (2 * self.size)
def _update(self, i):
self.data[i] = self.op(self.data[i << 1], self.data[i << 1 | 1])
def initialize(self, arr=None):
""" segtreeをarrで初期化する。len(arr) == Nにすること """
if arr:
for i, a in enumerate(arr, self.size):
self.data[i] = a
for i in reversed(range(1, self.size)):
self._update(i)
# self.len[i] = self.len[i << 1] + self.len[i << 1 | 1]
def update(self, p, x):
""" data[p] = x とする (0-indexed)"""
p += self.size
self.data[p] = x
for i in range(1, self.log + 1):
self._update(p >> i)
def get(self, p):
""" data[p]を返す """
return self.data[p + self.size]
def prod(self, l, r):
"""
op_data(data[l], data[l+1], ..., data[r-1])を返す (0-indexed)
"""
sml = self.e
smr = self.e
l += self.size
r += self.size
while l < r:
if l & 1:
sml = self.op(sml, self.data[l])
l += 1
if r & 1:
r -= 1
smr = self.op(self.data[r], smr)
l >>= 1
r >>= 1
return self.op(sml, smr)
def all_prod(self):
""" op(data[0], data[1], ... data[N-1])を返す """
return self.data[1]
class LowestCommonAncestor(SegTree):
def __init__(self, N, root, G):
self.n = N
self.depth = [0] * N
self.tout = [-1] * N
self.tin = [-1] * N
self.tin[root] = 0
euler = [0]
par = [-1] * N
itr = [0] * N
que = [root]
topo = []
while que:
s = que[-1]
if itr[s] < len(G[s]):
t = G[s][itr[s]]
itr[s] += 1
if t == par[s]:
continue
par[t] = s
self.depth[t] = self.depth[s] + 1
self.tin[t] = len(euler)
euler.append(N * self.depth[t] + t)
que.append(t)
else:
topo.append(s)
p = par[s]
self.tout[s] = len(euler)
euler.append(N * self.depth[p] + p)
que.pop()
euler.pop()
self.par = par
self.topo = topo
super().__init__(len(euler), min, N * N + 10)
self.initialize(euler)
def __call__(self, a, b):
"""LCA(a, b)を返す"""
l = min(self.tin[a], self.tin[b])
r = max(self.tout[a], self.tout[b])
return self.prod(l, r) % self.n
N = int(input())
G = [[] for _ in range(N)]
for _ in range(N - 1):
a, b = map(lambda x: int(x) - 1, input().split())
G[a].append(b)
G[b].append(a)
"""各頂点を通った回数がわかるとOK"""
LCA = LowestCommonAncestor(N, 0, G)
par = LCA.par
cnt = [0] * N
Q = int(input())
for _ in range(Q):
x, y = map(lambda x: int(x) - 1, input().split())
z = LCA(x, y)
cnt[x] += 1
cnt[y] += 1
cnt[z] -= 1
if z != 0:
cnt[par[z]] -= 1
for s in LCA.topo[:-1]:
cnt[par[s]] += cnt[s]
ans = 0
for c in cnt:
ans += c * (c + 1) // 2
print(ans)
tktk_snsn