結果

問題 No.1197 モンスターショー
ユーザー lam6er
提出日時 2025-03-20 20:57:07
言語 PyPy3
(7.3.15)
結果
TLE  
実行時間 -
コード長 3,903 bytes
コンパイル時間 396 ms
コンパイル使用メモリ 82,792 KB
実行使用メモリ 121,448 KB
最終ジャッジ日時 2025-03-20 20:57:15
合計ジャッジ時間 7,181 ms
ジャッジサーバーID
(参考情報)
judge3 / judge4
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample -- * 2
other AC * 7 TLE * 1 -- * 33
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
from collections import deque

def main():
    sys.setrecursionlimit(1 << 25)
    input = sys.stdin.read().split()
    ptr = 0
    N, K, Q = map(int, input[ptr:ptr+3])
    ptr += 3
    C = list(map(int, input[ptr:ptr+K]))
    ptr += K

    # Build adjacency list
    adj = [[] for _ in range(N+1)]
    for _ in range(N-1):
        a, b = map(int, input[ptr:ptr+2])
        ptr +=2
        adj[a].append(b)
        adj[b].append(a)

    # Compute parent and children using BFS
    root = 1
    parent = [0]*(N+1)
    children = [[] for _ in range(N+1)]
    visited = [False]*(N+1)
    q = deque([root])
    visited[root] = True
    parent[root] = 0

    while q:
        u = q.popleft()
        for v in adj[u]:
            if not visited[v] and v != parent[u]:
                visited[v] = True
                parent[v] = u
                children[u].append(v)
                q.append(v)

    # Compute in_time and out_time using iterative DFS
    in_time = [0]*(N+1)
    out_time = [0]*(N+1)
    time = 1
    stack = [(root, False)]

    while stack:
        node, visited_flag = stack.pop()
        if visited_flag:
            out_time[node] = time
            time +=1
            continue
        in_time[node] = time
        time +=1
        stack.append( (node, True) )
        # Add children in reverse order to process them in order
        for child in reversed(children[node]):
            stack.append( (child, False) )

    max_time = time-1
    bit = BIT(max_time)

    # Depth computation
    depth = [0]*(N+1)
    q = deque([root])
    visited = [False]*(N+1)
    visited[root] = True
    while q:
        u = q.popleft()
        for v in children[u]:
            if not visited[v]:
                visited[v] = True
                depth[v] = depth[u] + 1
                q.append(v)

    # Initialize BIT with initial positions
    sum_depth_x = 0
    current_pos = C.copy()  # 1-based indexes?
    for c in current_pos:
        bit.update(in_time[c], 1)
        sum_depth_x += depth[c]

    # Process queries
    for _ in range(Q):
        query = input[ptr:ptr+3]
        if query[0] == '1':
            # Type 1: move
            p = int(query[1]) -1  # 0-based index
            d = int(query[2])
            old_pos = current_pos[p]
            new_pos = d

            # Update BIT
            bit.update(in_time[old_pos], -1)
            bit.update(in_time[new_pos], +1)

            # Update sum_depth_x
            sum_depth_x -= depth[old_pos]
            sum_depth_x += depth[new_pos]

            current_pos[p] = new_pos
            ptr +=3
        else:
            # Type 2: query
            e = int(query[1])
            ptr +=2

            # Compute sum_lca_depth_e
            in_e = in_time[e]
            out_e = out_time[e]
            cnt_e = bit.query_range(in_e, out_e)
            sum_lca = cnt_e * depth[e]

            current = e
            while True:
                p = parent[current]
                if p ==0:
                    break
                # Get count[p] and count[current]
                cnt_p = bit.query_range(in_time[p], out_time[p])
                cnt_current = bit.query_range(in_time[current], out_time[current])
                sum_lca += (cnt_p - cnt_current) * depth[p]

                current = p

            total = sum_depth_x + K * depth[e] - 2 * sum_lca
            print(total)

class BIT:
    def __init__(self, size):
        self.size = size
        self.tree = [0]*(self.size +2)  # 1-based indexing

    def update(self, idx, delta):
        while idx <= self.size:
            self.tree[idx] += delta
            idx += idx & -idx

    def query(self, idx):
        res =0
        while idx >0:
            res += self.tree[idx]
            idx -= idx & -idx
        return res

    def query_range(self, l, r):
        return self.query(r) - self.query(l-1)

if __name__ == "__main__":
    main()
0