結果

問題 No.1197 モンスターショー
ユーザー qwewe
提出日時 2025-05-14 12:55:13
言語 PyPy3 (7.3.15)
結果
TLE  
実行時間 -
コード長 4,697 bytes
コンパイル時間 182 ms
コンパイル使用メモリ 82,628 KB
実行使用メモリ 129,020 KB
最終ジャッジ日時 2025-05-14 12:55:56
合計ジャッジ時間 5,873 ms
ジャッジサーバーID (参考情報) judge4 / judge5
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample -- * 2
other AC * 7 TLE * 1 -- * 33
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
from bisect import bisect

# main() does not recurse: its tree traversals use explicit stacks.
# This raised limit is therefore never exercised; it is harmless but not required.
sys.setrecursionlimit(1 << 25)

class _RangeFenwick:
    """Fenwick tree over 0-based indices 0..n-1 with range add and range sum.

    Uses the standard two-array construction. After adding x on the 1-based
    range [l, r], the sum of the first i elements is
    sum(b1[..i]) * i - sum(b2[..i]).
    """

    __slots__ = ("n", "b1", "b2")

    def __init__(self, n):
        self.n = n
        self.b1 = [0] * (n + 1)
        self.b2 = [0] * (n + 1)

    def _point(self, i, a, b):
        """Add a to b1 and b to b2 at 1-based index i (no-op when i > n)."""
        n, b1, b2 = self.n, self.b1, self.b2
        while i <= n:
            b1[i] += a
            b2[i] += b
            i += i & -i

    def add(self, lo, hi, x):
        """Add x to every element whose 0-based index lies in [lo, hi]."""
        self._point(lo + 1, x, x * lo)
        self._point(hi + 2, -x, -x * (hi + 1))

    def prefix(self, i):
        """Return the sum of the first i elements (0-based indices < i)."""
        s1 = s2 = 0
        j = i
        b1, b2 = self.b1, self.b2
        while j > 0:
            s1 += b1[j]
            s2 += b2[j]
            j -= j & -j
        return s1 * i - s2

    def range_sum(self, lo, hi):
        """Return the sum of elements whose 0-based index lies in [lo, hi]."""
        return self.prefix(hi + 1) - self.prefix(lo)


def _root_tree(adj, n, root):
    """Root the tree at `root` without recursion.

    Returns (parent, depth, order, heavy):
    - parent[root] is 0, and vertex 0 is an unused sentinel.
    - `order` lists every vertex after its parent.
    - heavy[v] is a child of v with the largest subtree, or 0 for a leaf.
    """
    parent = [0] * (n + 1)
    depth = [0] * (n + 1)
    order = []
    stack = [root]
    while stack:
        v = stack.pop()
        order.append(v)
        pv = parent[v]
        for u in adj[v]:
            if u != pv:
                parent[u] = v
                depth[u] = depth[v] + 1
                stack.append(u)

    size = [1] * (n + 1)
    size[0] = 0  # so any real child beats the "no heavy child yet" sentinel
    heavy = [0] * (n + 1)
    # Reverse order: every vertex is finished (size final) before its parent.
    for v in reversed(order):
        p = parent[v]
        if p:
            size[p] += size[v]
            if size[v] > size[heavy[p]]:
                heavy[p] = v
    return parent, depth, order, heavy


def _decompose(adj, parent, heavy, root, n):
    """Heavy-light decomposition: return (head, pos).

    Each chain is walked top-down and numbered consecutively. The vertices of
    one chain therefore occupy the contiguous position range
    [pos[head[v]], pos[v]], with the chain head holding the smallest position.
    """
    head = [0] * (n + 1)
    pos = [0] * (n + 1)
    cur = 0
    stack = [root]  # every popped vertex starts a new chain
    while stack:
        h = stack.pop()
        v = h
        while v:
            head[v] = h
            pos[v] = cur
            cur += 1
            hv = heavy[v]
            pv = parent[v]
            for u in adj[v]:
                if u != pv and u != hv:
                    stack.append(u)  # light child: head of its own chain
            v = hv
    return head, pos


def main():
    """Solve yukicoder No.1197 (Monster Show): read stdin, write answers to stdout.

    Input format: ``N K Q``, then the K monster positions C_1..C_K, then
    N-1 tree edges, then Q queries. Each query is either ``1 p d`` (move
    monster p, 1-based, to vertex d) or ``2 e``. A ``2 e`` query prints the
    sum over monsters c of dist(c, e).

    The method relies on dist(c, e) = depth[c] + depth[e] - 2 * depth[lca(c, e)].
    - Let w[v] be the number of monsters in v's subtree.
    - The sum of w[v] over the root->e path (root included) equals the sum
      over monsters of depth(lca) + 1.
    - Heavy-light decomposition plus a range-add/range-sum Fenwick tree
      makes each move and each query O(log^2 N).
    """
    data = sys.stdin.read().split()
    n, k, q = int(data[0]), int(data[1]), int(data[2])
    ptr = 3
    monsters = [int(x) for x in data[ptr:ptr + k]]
    ptr += k
    adj = [[] for _ in range(n + 1)]
    for _ in range(n - 1):
        a = int(data[ptr])
        b = int(data[ptr + 1])
        ptr += 2
        adj[a].append(b)
        adj[b].append(a)

    root = 1
    parent, depth, order, heavy = _root_tree(adj, n, root)
    head, pos = _decompose(adj, parent, heavy, root, n)

    # Initial w[v] = monsters in subtree(v), accumulated children-first.
    w = [0] * (n + 1)
    for c in monsters:
        w[c] += 1
    for v in reversed(order):
        if parent[v]:
            w[parent[v]] += w[v]
    # The initial w is stored as static prefix sums in HLD position order.
    # Later moves are layered on top through the Fenwick tree, so start-up
    # costs O(N) rather than K path updates.
    base_pre = [0] * (n + 1)
    base = [0] * n
    for v in range(1, n + 1):
        base[pos[v]] = w[v]
    for i, x in enumerate(base):
        base_pre[i + 1] = base_pre[i] + x
    fen = _RangeFenwick(n)

    def path_add(x, delta):
        """Add delta to w on every vertex of the root->x path."""
        while x:
            h = head[x]
            fen.add(pos[h], pos[x], delta)
            x = parent[h]

    def path_sum(x):
        """Return the sum of w over the root->x path, root included."""
        total = 0
        while x:
            h = head[x]
            lo, hi = pos[h], pos[x]
            total += base_pre[hi + 1] - base_pre[lo] + fen.range_sum(lo, hi)
            x = parent[h]
        return total

    sum_depth = sum(depth[c] for c in monsters)
    out = []
    for _ in range(q):
        if data[ptr] == '1':
            p = int(data[ptr + 1]) - 1
            d = int(data[ptr + 2])
            ptr += 3
            old = monsters[p]
            if old != d:
                path_add(old, -1)
                path_add(d, 1)
                sum_depth += depth[d] - depth[old]
                monsters[p] = d
        else:
            e = int(data[ptr + 1])
            ptr += 2
            # w[root] is always k, and each monster contributes depth(lca) + 1,
            # so subtract k to get the sum of depth(lca).
            sum_lca = path_sum(e) - k
            out.append(sum_depth + k * depth[e] - 2 * sum_lca)
    if out:
        print("\n".join(map(str, out)))

# Script entry point: run the solver only when executed directly, not on import.
if __name__ == '__main__':
    main()
0