結果
| 問題 | No.901 K-ary εxtrεεmε |
| コンテスト | |
| ユーザー | |
| 提出日時 | 2024-11-04 21:23:11 |
| 言語 | PyPy3 (7.3.15) |
| 結果 | AC |
| 実行時間 | 2,020 ms / 3,000 ms |
| コード長 | 11,050 bytes |
| コンパイル時間 | 212 ms |
| コンパイル使用メモリ | 82,392 KB |
| 実行使用メモリ | 228,796 KB |
| 最終ジャッジ日時 | 2024-11-04 21:24:26 |
| 合計ジャッジ時間 | 45,089 ms |
| ジャッジサーバーID (参考情報) | judge4 / judge1 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 1 |
| other | AC * 29 |
ソースコード
import sys
from collections import *
from functools import *
sys.setrecursionlimit(1_000_000)  # generous headroom for deep tree traversals
class SegTree:
    """
    Iterative (bottom-up) segment tree.

    segfunc must be associative and ide_ele must be its identity element.
    segfunc does NOT need to be commutative: both update() and query()
    combine operands strictly in left-to-right array order.
    """

    def __init__(self, init_val, segfunc, ide_ele):
        """
        Build the tree.

        init_val: initial array values (position i holds init_val[i])
        segfunc: associative binary operation
        ide_ele: identity element of segfunc
        """
        n = len(init_val)
        self.segfunc = segfunc
        self.ide_ele = ide_ele
        # Number of leaves: a power of two that is >= n.
        self.num = 1 << (n - 1).bit_length()
        self.tree = [ide_ele] * (2 * self.num)
        # Leaves live at indices [num, num + n).
        for i, val in enumerate(init_val):
            self.tree[self.num + i] = val
        # Internal node i aggregates its children 2i (left) and 2i+1 (right).
        for i in range(self.num - 1, 0, -1):
            self.tree[i] = segfunc(self.tree[2 * i], self.tree[2 * i + 1])

    def update(self, k, x):
        """
        Set position k to x and recompute all of its ancestors.

        k: index (0-indexed)
        x: new value
        """
        k += self.num
        self.tree[k] = x
        k >>= 1
        while k:
            # Always left child first, so non-commutative segfuncs stay correct.
            self.tree[k] = self.segfunc(self.tree[2 * k], self.tree[2 * k + 1])
            k >>= 1

    def query(self, l, r):
        """
        Return segfunc folded over positions [l, r) in order
        (ide_ele for an empty range).

        l: left index, inclusive (0-indexed)
        r: right index, exclusive (0-indexed)
        """
        # Nodes from the left edge are appended to `left`; nodes from the right
        # edge are prepended to `right`, so the final fold keeps array order.
        left = self.ide_ele
        right = self.ide_ele
        l += self.num
        r += self.num
        while l < r:
            if l & 1:
                left = self.segfunc(left, self.tree[l])
                l += 1
            if r & 1:
                r -= 1
                right = self.segfunc(self.tree[r], right)
            l >>= 1
            r >>= 1
        return self.segfunc(left, right)
class HLD:
    """
    Heavy-light decomposition of a tree.

    Vertices are relabelled in a DFS order that visits the heavy child first:
    self.ID[v] is the new label of original vertex v. Under this labelling
    every heavy chain and every subtree is a contiguous range of labels, so a
    segment tree built over values arranged by label can answer path and
    subtree queries with the half-open [l, r) ranges returned here.

    Arrays indexed by NEW label:
        PAR[i]   -- label of the parent of i (-1 for the root)
        HEAD[i]  -- label of the top vertex of i's heavy chain
        DEPTH[i] -- number of light edges between the root and i
        SUB[i]   -- number of vertices in i's subtree
    """

    def __init__(self, e, root=0):
        """
        e: adjacency list of the tree (e[v] lists the neighbours of v)
        root: root vertex (original label)
        """
        self.N = len(e)
        self.e = e
        self.root = root
        par = [-1] * self.N  # parent, original labels (-1 for the root)
        sub = [-1] * self.N  # subtree size, original labels

        # BFS distances (edge counts) from the root.
        dist = [-1] * self.N
        dist[root] = 0
        queue = deque([root])
        while queue:
            x = queue.popleft()
            for ix in e[x]:
                if dist[ix] == -1:
                    dist[ix] = dist[x] + 1
                    queue.append(ix)

        # Deepest vertices first: when i is processed all of its children
        # already have a size, so the one neighbour still at -1 is its parent.
        for _, i in sorted((-dist[i], i) for i in range(self.N)):
            size = 1
            for ix in e[i]:
                if sub[ix] == -1:
                    par[i] = ix
                else:
                    size += sub[ix]
            sub[i] = size

        self.ID = [-1] * self.N
        self.ID[self.root] = 0
        self.HEAD = [-1] * self.N
        head = [-1] * self.N  # chain head, original labels
        self.PAR = [-1] * self.N
        visited = [False] * self.N
        self.HEAD[0] = 0
        # The root heads its own chain.
        head[self.root] = self.root
        depth = [-1] * self.N  # light-edge depth, original labels
        depth[self.root] = 0
        self.DEPTH = [-1] * self.N
        self.DEPTH[0] = 0
        cnt = 0
        # appendleft + popleft make this deque a LIFO stack (iterative DFS).
        stack = deque([self.root])
        self.SUB = [0] * self.N
        self.SUB[0] = self.N
        while stack:
            x = stack.popleft()
            visited[x] = True
            self.ID[x] = cnt
            cnt += 1
            deg = len(self.e[x])
            # Neighbours by ascending subtree size. For a non-root x the parent
            # is larger than x (hence larger than every child), so it is last
            # and the heavy child is second-to-last; at the root the heavy
            # child is last. Starting flg at -1 for the root makes
            # `flg == deg - 1` select the heavy child in both cases.
            ordered = sorted((sub[ix], ix) for ix in self.e[x])
            flg = -1 if x == self.root else 0
            for _, ix in ordered:
                flg += 1
                if visited[ix]:
                    continue
                # The heavy child is pushed last, so it is popped (labelled)
                # next, which keeps its chain contiguous.
                stack.appendleft(ix)
                if flg == deg - 1:
                    head[ix] = head[x]
                    depth[ix] = depth[x]
                else:
                    head[ix] = ix
                    depth[ix] = depth[x] + 1
        for i in range(self.N):
            # The root has no parent: keep -1 instead of indexing ID[-1].
            self.PAR[self.ID[i]] = self.ID[par[i]] if par[i] != -1 else -1
            self.HEAD[self.ID[i]] = self.ID[head[i]]
            self.DEPTH[self.ID[i]] = depth[i]
            self.SUB[self.ID[i]] = sub[i]

    def path_query(self, l, r):
        """
        Return a list of half-open label ranges [a, b) whose union is exactly
        the set of vertices on the path between original vertices l and r
        (both endpoints included). The ranges are not in path order.
        """
        L = self.ID[l]
        R = self.ID[r]
        res = []
        if self.DEPTH[L] < self.DEPTH[R]:
            L, R = R, L
        # Lift the endpoint with more light edges until both have equal counts.
        while self.DEPTH[L] != self.DEPTH[R]:
            res.append((self.HEAD[L], L + 1))
            L = self.PAR[self.HEAD[L]]
        # Lift both, one chain at a time, until they share a chain.
        while self.HEAD[L] != self.HEAD[R]:
            res.append((self.HEAD[L], L + 1))
            L = self.PAR[self.HEAD[L]]
            res.append((self.HEAD[R], R + 1))
            R = self.PAR[self.HEAD[R]]
        if L > R:
            L, R = R, L
        res.append((L, R + 1))
        return res

    def sub_query(self, k):
        """
        Return the half-open label range [a, b) covering the subtree of
        original vertex k.
        """
        K = self.ID[k]
        return (K, K + self.SUB[K])
class HLD_SegTree:
    """
    Segment tree over a tree's vertices, laid out in heavy-light order so
    that path and subtree aggregates reduce to a few contiguous range queries.
    """

    def __init__(self, e, init_val, segfunc, ide_ele, root=0):
        """
        e: adjacency list of the tree
        init_val: init_val[v] is the value stored on original vertex v
        segfunc / ide_ele: operation and its identity, forwarded to SegTree
        root: root vertex used for the decomposition
        """
        self.hld = HLD(e, root=root)
        self.ID = self.hld.ID[:]
        self.N = len(e)
        # Place each vertex's value at its HLD label.
        A = [ide_ele] * self.N
        for v, idx in enumerate(self.ID):
            A[idx] = init_val[v]
        self.seg = SegTree(A, segfunc, ide_ele)
        self.segfunc = segfunc
        self.ide_ele = ide_ele

    def path_query(self, l, r):
        """
        Aggregate of the values on every vertex of the path between vertices
        l and r (both inclusive). Chain ranges are combined in the order
        HLD.path_query lists them, not in path order, so the result is only
        well-defined for a commutative segfunc.
        """
        res = self.ide_ele
        for _l, _r in self.hld.path_query(l, r):
            res = self.segfunc(res, self.seg.query(_l, _r))
        return res

    def sub_query(self, l, r=None):
        """
        Aggregate of the values in the subtree rooted at vertex l.

        HLD.sub_query takes a single vertex, so only l is forwarded. r is
        accepted and ignored to keep the two-argument signature valid.
        """
        _l, _r = self.hld.sub_query(l)
        return self.seg.query(_l, _r)
class AuxiliaryTree:
    """
    Tree rooted at vertex 0 supporting O(1) lowest-common-ancestor queries
    (Euler tour + sparse table) and construction of auxiliary (virtual)
    trees over vertex subsets.
    """

    def __init__(self, e):
        """
        Initializes the AuxiliaryTree with an adjacency list.
        :param e: List of lists representing the adjacency list of the tree.
        """
        self.n = len(e)
        self.adj = e
        self.depth = [0] * self.n
        self.parent = [-1] * self.n
        self.euler = []
        self.first_occurrence = [-1] * self.n
        # The Euler tour has 2n - 1 entries, so log[0 .. 2n-1] covers every
        # query length.
        self.log = [0] * (2 * self.n)
        self.sparse_table = []
        self._initialize_log()
        if self.n:  # an empty tree has no vertex 0 to start from
            self._recursive_dfs(0, -1, 0)
        self._build_sparse_table()

    def _initialize_log(self):
        """
        Fill self.log so that log[i] == floor(log2(i)) for every i >= 1.
        """
        for i in range(2, len(self.log)):
            self.log[i] = self.log[i // 2] + 1

    def _recursive_dfs(self, node, par, dep):
        """
        DFS from `node`, recording the Euler tour, first occurrences, depths
        and parents. Uses an explicit stack (the name is kept for
        compatibility), so deep trees do not hit Python's recursion limit.
        The resulting tour is identical to the recursive formulation:
        a vertex is emitted on entry and again after each child finishes.

        :param node: Start vertex.
        :param par: Parent of the start vertex (-1 for the root).
        :param dep: Depth of the start vertex.
        """
        adj = self.adj
        euler = self.euler
        first = self.first_occurrence
        depth = self.depth
        parent = self.parent

        def enter(v, p, d):
            # First visit of v: record tour position, parent and depth.
            if first[v] == -1:
                first[v] = len(euler)
            euler.append(v)
            parent[v] = p
            depth[v] = d

        enter(node, par, dep)
        # Frames: (vertex, its parent, iterator over its remaining neighbours).
        stack = [(node, par, iter(adj[node]))]
        while stack:
            v, p, neighbours = stack[-1]
            for w in neighbours:
                if w != p:
                    enter(w, v, depth[v] + 1)
                    stack.append((w, v, iter(adj[w])))
                    break
            else:
                # v is finished; its parent reappears in the tour.
                stack.pop()
                if stack:
                    euler.append(stack[-1][0])

    def _build_sparse_table(self):
        """
        Build sparse_table[j][i] = the shallowest vertex among
        euler[i : i + 2**j], used by lca() for O(1) range-minimum queries.
        Slots whose window would run past the tour are left as 0.
        """
        m = len(self.euler)
        if m == 0:
            self.sparse_table = []
            return
        depth = self.depth
        table = [self.euler[:]]
        j = 1
        while (1 << j) <= m:
            prev = table[j - 1]
            half = 1 << (j - 1)
            row = [0] * m
            for i in range(m - (1 << j) + 1):
                a = prev[i]
                b = prev[i + half]
                row[i] = a if depth[a] < depth[b] else b
            table.append(row)
            j += 1
        self.sparse_table = table

    def lca(self, u, v):
        """
        Computes the lowest common ancestor (LCA) of two nodes u and v.
        :param u: First node.
        :param v: Second node.
        :return: The LCA of nodes u and v.
        """
        left = self.first_occurrence[u]
        right = self.first_occurrence[v]
        if left > right:
            left, right = right, left
        length = right - left + 1
        j = self.log[length]
        # Two overlapping power-of-two windows cover [left, right].
        left_interval = self.sparse_table[j][left]
        right_interval = self.sparse_table[j][right - (1 << j) + 1]
        if self.depth[left_interval] < self.depth[right_interval]:
            return left_interval
        return right_interval

    def build_auxiliary_tree(self, nodes):
        """
        Constructs an auxiliary tree from a subset of nodes.

        The result contains the given nodes plus the pairwise LCAs needed to
        connect them; each edge stands for the tree path between its
        endpoints (one endpoint is always an ancestor of the other).
        Duplicate entries in `nodes` are ignored.

        :param nodes: Iterable of nodes to include in the auxiliary tree.
        :return: A dictionary mapping each node to a list of its neighbours
            (both directions of every edge are present; nodes with no edges
            are omitted).
        """
        # Deduplicate so a repeated vertex cannot create a self-loop edge.
        nodes = sorted(set(nodes), key=lambda x: self.first_occurrence[x])
        stack = []  # a root-to-leaf chain of vertices, in tour order
        auxiliary_tree = defaultdict(set)  # sets avoid duplicate edges
        for node in nodes:
            if not stack:
                stack.append(node)
                continue
            lca_node = self.lca(stack[-1], node)
            # Close off every chain vertex that lies below lca_node.
            while (
                len(stack) > 1
                and self.first_occurrence[stack[-2]] >= self.first_occurrence[lca_node]
            ):
                top = stack.pop()
                auxiliary_tree[stack[-1]].add(top)
                auxiliary_tree[top].add(stack[-1])
            if stack[-1] != lca_node:
                top = stack.pop()
                auxiliary_tree[lca_node].add(top)
                auxiliary_tree[top].add(lca_node)
                stack.append(lca_node)
            stack.append(node)
        # Connect what is left of the chain.
        while len(stack) > 1:
            top = stack.pop()
            auxiliary_tree[stack[-1]].add(top)
            auxiliary_tree[top].add(stack[-1])
        return {k: list(v) for k, v in auxiliary_tree.items()}
# Driver for yukicoder No.901 "K-ary εxtrεεmε".
# Input (0-indexed vertices): N, then N-1 lines "u v w" describing a weighted
# tree, then Q, then Q lines "k x_1 ... x_k". For every query print the sum,
# over the auxiliary-tree edges of x_1..x_k, of the weight of the underlying
# tree path, i.e. the weight of the minimal subtree connecting them.
#
# W[v] is the weight of the edge from v to its parent (W[root] = 0). So for an
# ancestor a of d, the path weight is (sum of W over the path's vertices) - W[a].
_tokens = iter(sys.stdin.buffer.read().split())
N = int(next(_tokens))
e = [[] for _ in range(N)]
W = [0] * N
edge = {}
for _ in range(N - 1):
    u, v, w = int(next(_tokens)), int(next(_tokens)), int(next(_tokens))
    e[u].append(v)
    e[v].append(u)
    edge[(u, v)] = w
tree = AuxiliaryTree(e)
# Store each edge weight on its child endpoint (AuxiliaryTree roots at 0).
for (u, v), w in edge.items():
    if v == tree.parent[u]:
        W[u] = w
    else:
        W[v] = w
hld = HLD_SegTree(e, W, lambda x, y: x + y, 0)
Q = int(next(_tokens))
out = []
for _ in range(Q):
    k = int(next(_tokens))
    K = [int(next(_tokens)) for _ in range(k)]
    at = tree.build_auxiliary_tree(K)
    ans = 0
    for i, nbrs in at.items():
        for j in nbrs:
            # Every auxiliary edge is listed from both ends; count it once.
            if i < j:
                ans += hld.path_query(i, j) - W[tree.lca(i, j)]
    out.append(ans)
if out:
    sys.stdout.write("\n".join(map(str, out)) + "\n")