結果
| 問題 |
No.1197 モンスターショー
|
| コンテスト | |
| ユーザー |
qwewe
|
| 提出日時 | 2025-05-14 12:55:13 |
| 言語 | PyPy3 (7.3.15) |
| 結果 |
TLE
|
| 実行時間 | - |
| コード長 | 4,697 bytes |
| コンパイル時間 | 182 ms |
| コンパイル使用メモリ | 82,628 KB |
| 実行使用メモリ | 129,020 KB |
| 最終ジャッジ日時 | 2025-05-14 12:55:56 |
| 合計ジャッジ時間 | 5,873 ms |
|
ジャッジサーバーID (参考情報) |
judge4 / judge5 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | -- * 2 |
| other | AC * 7 TLE * 1 -- * 33 |
ソースコード
import sys
from bisect import bisect
sys.setrecursionlimit(1 << 25)
def main():
input = sys.stdin.read().split()
ptr = 0
N, K, Q = int(input[ptr]), int(input[ptr+1]), int(input[ptr+2])
ptr +=3
C = list(map(int, input[ptr:ptr+K]))
ptr +=K
adj = [[] for _ in range(N+1)]
for _ in range(N-1):
a = int(input[ptr])
b = int(input[ptr+1])
adj[a].append(b)
adj[b].append(a)
ptr +=2
root = 1
parent = [0]*(N+1)
depth = [0]*(N+1)
in_time = [0]*(N+1)
out_time = [0]*(N+1)
children = [[] for _ in range(N+1)]
time = 0
stack = [(root, False)]
while stack:
node, visited = stack.pop()
if visited:
out_time[node] = time
continue
in_time[node] = time
time +=1
stack.append( (node, True) )
temp = []
for neighbor in adj[node]:
if neighbor != parent[node]:
parent[neighbor] = node
depth[neighbor] = depth[node] +1
temp.append(neighbor)
temp.sort()
for child in reversed(temp):
stack.append( (child, False) )
children[node] = temp
size = [1]*(N+1)
heavy_child = [0]*(N+1)
stack = [(root, False)]
while stack:
node, visited = stack.pop()
if visited:
max_size = 0
for child in children[node]:
size[node] += size[child]
if size[child] > max_size:
max_size = size[child]
heavy_child[node] = child
continue
stack.append( (node, True) )
for child in reversed(children[node]):
stack.append( (child, False) )
top = [0]*(N+1)
stack = [(root, root)]
while stack:
node, current_top = stack.pop()
top[node] = current_top
for child in reversed(children[node]):
if child == heavy_child[node]:
stack.append( (child, current_top) )
else:
stack.append( (child, child) )
class BIT:
def __init__(self, size):
self.n = size
self.tree = [0]*(self.n +2)
def update(self, idx, delta):
idx +=1
while idx <= self.n +1:
self.tree[idx] += delta
idx += idx & -idx
def query_prefix(self, idx):
idx +=1
res = 0
while idx >0:
res += self.tree[idx]
idx -= idx & -idx
return res
def query_range(self, l, r):
return self.query_prefix(r) - self.query_prefix(l-1)
bit = BIT(time)
sum_depth_s = 0
cnt = [0]*(N+1)
for c in C:
cnt[c] +=1
sum_depth_s += depth[c]
for i in range(1, N+1):
if cnt[i]:
bit.update(in_time[i], cnt[i])
for _ in range(Q):
query = input[ptr]
ptr +=1
if query == '1':
p = int(input[ptr])-1
d = int(input[ptr+1])
ptr +=2
old_c = C[p]
bit.update(in_time[old_c], -1)
sum_depth_s -= depth[old_c]
C[p] = d
bit.update(in_time[d], 1)
sum_depth_s += depth[d]
else:
e = int(input[ptr])
ptr +=1
current_e = e
sum_lca = 0
while current_e != 0:
t = top[current_e]
v = current_e
while True:
cnt_sub_v = bit.query_range(in_time[v], out_time[v]-1)
left = 0
right = len(children[v]) -1
child_e = None
while left <= right:
mid = (left + right) //2
ch = children[v][mid]
if in_time[ch] <= in_time[e] <= out_time[ch]-1:
child_e = ch
break
elif in_time[e] < in_time[ch]:
right = mid -1
else:
left = mid +1
if child_e is not None:
cnt_sub_child = bit.query_range(in_time[child_e], out_time[child_e]-1)
sum_lca += (cnt_sub_v - cnt_sub_child) * depth[v]
else:
sum_lca += cnt_sub_v * depth[v]
if v == t:
break
v = parent[v]
current_e = parent[t]
total = sum_depth_s + K * depth[e] - 2 * sum_lca
print(total)
if __name__ == '__main__':
main()
qwewe