結果

問題 No.1197 モンスターショー
ユーザー lam6er
提出日時 2025-04-09 21:05:11
言語 PyPy3
(7.3.15)
結果
TLE  
実行時間 -
コード長 4,515 bytes
コンパイル時間 525 ms
コンパイル使用メモリ 81,888 KB
実行使用メモリ 55,996 KB
最終ジャッジ日時 2025-04-09 21:06:52
合計ジャッジ時間 5,698 ms
ジャッジサーバーID
(参考情報)
judge1 / judge4
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample -- * 2
other AC * 7 TLE * 1 -- * 33
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
from sys import stdin
from collections import deque

class _RangeFenwick:
    """Fenwick tree over positions 0..n-1 with range add and range sum.

    Uses the two-tree formulation: for a 1-based index x,
    prefix(x) = sum(b1[..x]) * x - sum(b2[..x]).  Initial values (``base``)
    are held in a static prefix-sum array, so construction is O(n) rather
    than n individual updates.
    """

    def __init__(self, n, base=None):
        self.n = n
        self.b1 = [0] * (n + 1)
        self.b2 = [0] * (n + 1)
        pref = [0] * (n + 1)
        if base is not None:
            running = 0
            for i, value in enumerate(base):
                running += value
                pref[i + 1] = running
        self.pref = pref

    def _add(self, tree, i, delta):
        """Point-add *delta* at 1-based index *i* of *tree* (no-op past n)."""
        n = self.n
        while i <= n:
            tree[i] += delta
            i += i & -i

    def _prefix(self, x):
        """Return the sum of positions [0, x)."""
        b1 = self.b1
        b2 = self.b2
        s1 = s2 = 0
        i = x
        while i > 0:
            s1 += b1[i]
            s2 += b2[i]
            i -= i & -i
        return s1 * x - s2 + self.pref[x]

    def range_add(self, lo, hi, delta):
        """Add *delta* to every position in [lo, hi] (0-based, inclusive)."""
        self._add(self.b1, lo + 1, delta)
        self._add(self.b1, hi + 2, -delta)
        self._add(self.b2, lo + 1, delta * lo)
        self._add(self.b2, hi + 2, -delta * (hi + 1))

    def range_sum(self, lo, hi):
        """Return the sum of positions [lo, hi] (0-based, inclusive)."""
        return self._prefix(hi + 1) - self._prefix(lo)


def _build_hld(n, adj, root=0):
    """Heavy-light decompose the tree given by adjacency lists *adj*.

    Returns ``(parent, depth, order, head, pos)``:
    parent[v] is v's parent (-1 for the root), depth[v] its edge distance
    from the root, order a BFS order starting at the root, head[v] the top
    node of v's heavy chain, and pos[v] v's index in a linear layout where
    every heavy chain is contiguous with its head first.
    """
    parent = [-1] * n
    depth = [0] * n
    order = [root]
    i = 0
    while i < len(order):
        u = order[i]
        i += 1
        for w in adj[u]:
            if w != parent[u]:
                parent[w] = u
                depth[w] = depth[u] + 1
                order.append(w)

    # Reverse BFS order visits every node after all of its descendants,
    # so size[u] is final when it is folded into its parent.
    size = [1] * n
    heavy = [-1] * n
    for u in reversed(order):
        p = parent[u]
        if p >= 0:
            size[p] += size[u]
            if heavy[p] < 0 or size[u] > size[heavy[p]]:
                heavy[p] = u

    head = [0] * n
    pos = [0] * n
    nxt = 0
    stack = [root]
    while stack:
        h = stack.pop()
        v = h
        # Lay out the whole heavy chain starting at h before any light child.
        while v >= 0:
            head[v] = h
            pos[v] = nxt
            nxt += 1
            hv = heavy[v]
            for w in adj[v]:
                if w != parent[v] and w != hv:
                    stack.append(w)
            v = hv
    return parent, depth, order, head, pos


def solve(tokens):
    """Answer the queries of yukicoder No.1197 from whitespace-split input.

    tokens: ``N K Q``, then slime positions C_1..C_K, then N-1 edges, then
    Q queries, each either ``1 p d`` (move slime p to vertex d) or ``2 e``
    (report the total distance from every slime to vertex e).  All vertex
    and slime numbers are 1-based.  Tokens may be str, bytes or int.

    Returns the answers to the type-2 queries, in order.
    """
    n, k, q = int(tokens[0]), int(tokens[1]), int(tokens[2])
    ptr = 3
    slimes = [int(x) - 1 for x in tokens[ptr:ptr + k]]
    ptr += k
    adj = [[] for _ in range(n)]
    for _ in range(n - 1):
        a = int(tokens[ptr]) - 1
        b = int(tokens[ptr + 1]) - 1
        ptr += 2
        adj[a].append(b)
        adj[b].append(a)

    parent, depth, order, head, pos = _build_hld(n, adj)

    # cnt[v] = number of slimes in v's subtree.  Summing cnt over the
    # root->e path counts, for each slime c, the common ancestors of c and e
    # (both inclusive), i.e. depth(lca(c, e)) + 1.
    cnt = [0] * n
    for c in slimes:
        cnt[c] += 1
    for u in reversed(order):
        if parent[u] >= 0:
            cnt[parent[u]] += cnt[u]
    base = [0] * n
    for v, c in enumerate(cnt):
        base[pos[v]] = c
    fw = _RangeFenwick(n, base)

    def path_add(v, delta):
        """Add *delta* to every node on the root->v path."""
        while v >= 0:
            h = head[v]
            fw.range_add(pos[h], pos[v], delta)
            v = parent[h]

    def path_sum(v):
        """Return the sum of node values on the root->v path."""
        total = 0
        while v >= 0:
            h = head[v]
            total += fw.range_sum(pos[h], pos[v])
            v = parent[h]
        return total

    sum_depth = sum(depth[c] for c in slimes)
    answers = []
    for _ in range(q):
        kind = int(tokens[ptr])
        if kind == 1:
            p = int(tokens[ptr + 1]) - 1
            d = int(tokens[ptr + 2]) - 1
            ptr += 3
            old = slimes[p]
            if old != d:
                # Moving one slime changes cnt by -1 on root->old and
                # +1 on root->d.
                path_add(old, -1)
                path_add(d, 1)
                sum_depth += depth[d] - depth[old]
                slimes[p] = d
        else:
            e = int(tokens[ptr + 1]) - 1
            ptr += 2
            sum_lca = path_sum(e) - k
            # dist(e, c) = depth(e) + depth(c) - 2 * depth(lca(e, c))
            answers.append(k * depth[e] + sum_depth - 2 * sum_lca)
    return answers


def main():
    """Read the problem input from stdin and print one answer per type-2 query."""
    answers = solve(sys.stdin.buffer.read().split())
    if answers:
        sys.stdout.write("\n".join(map(str, answers)) + "\n")


if __name__ == "__main__":
    main()
0