結果
問題 |
No.1197 モンスターショー
|
ユーザー |
![]() |
提出日時 | 2025-05-14 12:55:13 |
言語 | PyPy3 (7.3.15) |
結果 |
TLE
|
実行時間 | - |
コード長 | 4,697 bytes |
コンパイル時間 | 182 ms |
コンパイル使用メモリ | 82,628 KB |
実行使用メモリ | 129,020 KB |
最終ジャッジ日時 | 2025-05-14 12:55:56 |
合計ジャッジ時間 | 5,873 ms |
ジャッジサーバーID (参考情報) |
judge4 / judge5 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | -- * 2 |
other | AC * 7 TLE * 1 -- * 33 |
ソースコード
import sys
from bisect import bisect  # unused by this program; kept from the original submission


def _root_tree(adj, root):
    """Root the tree at ``root`` using an iterative BFS.

    Returns ``(parent, depth, order)``:
      * ``parent[root] == 0`` -- vertex labels are 1-based, so 0 is a sentinel;
      * ``depth[root] == 0``;
      * ``order`` lists the vertices with every parent before its children.
    """
    n = len(adj) - 1
    parent = [0] * (n + 1)
    depth = [0] * (n + 1)
    order = [root]
    head = 0
    while head < len(order):
        v = order[head]
        head += 1
        pv = parent[v]
        dv = depth[v] + 1
        for u in adj[v]:
            if u != pv:
                parent[u] = v
                depth[u] = dv
                order.append(u)
    return parent, depth, order


def _heavy_light(adj, parent, order, root):
    """Heavy-light decomposition of the rooted tree.

    Returns ``(top, pos)``. ``top[v]`` is the head of v's heavy chain.
    ``pos[v]`` is a 1-based position in which every heavy chain occupies
    a contiguous range, with the head at the smallest position.
    ``pos[root] == 1``.
    """
    n = len(adj) - 1
    size = [1] * (n + 1)
    heavy = [0] * (n + 1)
    # Reverse BFS order visits children before parents, so size[v] is final
    # when v is folded into its parent.
    for v in reversed(order):
        p = parent[v]
        if p:
            size[p] += size[v]
            h = heavy[p]
            if h == 0 or size[v] > size[h]:
                heavy[p] = v

    top = [0] * (n + 1)
    pos = [0] * (n + 1)
    top[root] = root
    next_pos = 1
    stack = [root]
    while stack:
        v = stack.pop()
        pos[v] = next_pos
        next_pos += 1
        h = heavy[v]
        pv = parent[v]
        for u in adj[v]:
            if u != pv and u != h:
                top[u] = u  # a light child starts a new chain
                stack.append(u)
        if h:
            # Pushed last so it is popped next: pos[h] == pos[v] + 1.
            top[h] = top[v]
            stack.append(h)
    return top, pos


class _RangeBIT:
    """Fenwick tree over positions 1..n with range add and range sum.

    Two point-update trees are kept over the difference array ``D``:
    ``_b1`` holds ``D[i]`` and ``_b2`` holds ``D[i] * (i - 1)``.
    Then ``prefix(i) = i * sum(D[1..i]) - sum(D[j] * (j - 1), j <= i)``.
    """

    def __init__(self, values):
        # values[0] is ignored; values[1..n] is the initial array.
        n = len(values) - 1
        self.n = n
        b1 = [0] * (n + 1)
        b2 = [0] * (n + 1)
        prev = 0
        for i in range(1, n + 1):
            d = values[i] - prev
            prev = values[i]
            b1[i] = d
            b2[i] = d * (i - 1)
        # O(n) Fenwick build: push each node's partial sum to its parent.
        for i in range(1, n + 1):
            j = i + (i & -i)
            if j <= n:
                b1[j] += b1[i]
                b2[j] += b2[i]
        self._b1 = b1
        self._b2 = b2

    def _add_diff(self, i, x):
        """Add ``x`` to ``D[i]`` (no-op when ``i > n``)."""
        n = self.n
        b1 = self._b1
        b2 = self._b2
        y = x * (i - 1)
        while i <= n:
            b1[i] += x
            b2[i] += y
            i += i & -i

    def _prefix(self, i):
        """Return ``values[1] + ... + values[i]`` (0 when ``i <= 0``)."""
        b1 = self._b1
        b2 = self._b2
        s1 = s2 = 0
        j = i
        while j > 0:
            s1 += b1[j]
            s2 += b2[j]
            j -= j & -j
        return s1 * i - s2

    def range_add(self, lo, hi, x):
        """Add ``x`` to every position in ``[lo, hi]``."""
        self._add_diff(lo, x)
        self._add_diff(hi + 1, -x)

    def range_sum(self, lo, hi):
        """Return the sum of positions ``[lo, hi]``."""
        return self._prefix(hi) - self._prefix(lo - 1)


def solve(data):
    """Answer every query of yukicoder No.1197 and return the answers.

    ``data`` is the whole input as ``str`` or ``bytes``:
    ``N K Q``, then ``C_1 .. C_K`` (each monster's starting vertex),
    then ``N-1`` edges ``a b``, then ``Q`` queries.

    * ``1 p d`` moves monster ``p`` (1-based) to vertex ``d``.
    * Any other query kind, written ``2 e``, asks for the sum of distances
      from every monster to vertex ``e``.

    Returns the type-2 answers in input order.

    Method: ``dist(c, e) = depth[c] + depth[e] - 2 * depth[lca(c, e)]``.
    The term ``sum_c depth[lca(c, e)]`` equals the sum, over the non-root
    vertices v on the root->e path, of the number of monsters in subtree(v).
    That per-vertex count lives on a heavy-light decomposition backed by a
    range-add/range-sum Fenwick tree. Each move and each query therefore
    costs O(log^2 N). Walking the whole root path per query, O(depth) each,
    is what timed out before.
    """
    tokens = data.split()
    n, k, q = int(tokens[0]), int(tokens[1]), int(tokens[2])
    ptr = 3
    where = [int(x) for x in tokens[ptr:ptr + k]]
    ptr += k

    adj = [[] for _ in range(n + 1)]
    for _ in range(n - 1):
        a = int(tokens[ptr])
        b = int(tokens[ptr + 1])
        ptr += 2
        adj[a].append(b)
        adj[b].append(a)

    root = 1
    parent, depth, order = _root_tree(adj, root)
    top, pos = _heavy_light(adj, parent, order, root)

    # below[v] = number of monsters currently inside the subtree of v.
    below = [0] * (n + 1)
    for c in where:
        below[c] += 1
    for v in reversed(order):
        if parent[v]:
            below[parent[v]] += below[v]
    initial = [0] * (n + 1)
    for v in order:
        if v != root:  # the root adds nothing to depth, so it is never counted
            initial[pos[v]] = below[v]
    bit = _RangeBIT(initial)

    def path_add(v, delta):
        """Add ``delta`` to every non-root vertex on the root->v path."""
        while v:
            t = top[v]
            lo = pos[t] + 1 if t == root else pos[t]
            if lo <= pos[v]:
                bit.range_add(lo, pos[v], delta)
            v = parent[t]

    def path_sum(v):
        """Sum the stored counts of the non-root vertices on root->v."""
        total = 0
        while v:
            t = top[v]
            lo = pos[t] + 1 if t == root else pos[t]
            if lo <= pos[v]:
                total += bit.range_sum(lo, pos[v])
            v = parent[t]
        return total

    depth_total = sum(depth[c] for c in where)
    answers = []
    for _ in range(q):
        kind = int(tokens[ptr])
        if kind == 1:
            p = int(tokens[ptr + 1]) - 1
            d = int(tokens[ptr + 2])
            ptr += 3
            old = where[p]
            if old != d:
                path_add(old, -1)
                path_add(d, 1)
                depth_total += depth[d] - depth[old]
                where[p] = d
        else:
            e = int(tokens[ptr + 1])
            ptr += 2
            answers.append(depth_total + k * depth[e] - 2 * path_sum(e))
    return answers


def main():
    """Read the input from stdin and print one line per type-2 answer."""
    answers = solve(sys.stdin.buffer.read())
    if answers:
        sys.stdout.write("\n".join(map(str, answers)) + "\n")


if __name__ == '__main__':
    main()