結果

問題 No.1197 モンスターショー
ユーザー qwewe
提出日時 2025-05-14 12:55:13
言語 PyPy3
(7.3.15)
結果
TLE  
実行時間 -
コード長 3,374 bytes
コンパイル時間 321 ms
コンパイル使用メモリ 82,104 KB
実行使用メモリ 56,516 KB
最終ジャッジ日時 2025-05-14 12:55:56
合計ジャッジ時間 5,866 ms
ジャッジサーバーID
(参考情報)
judge2 / judge3
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample -- * 2
other AC * 7 TLE * 1 -- * 33
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
from collections import deque

def main():
    input = sys.stdin.read().split()
    ptr = 0
    N, K, Q = map(int, input[ptr:ptr+3])
    ptr += 3
    C = list(map(int, input[ptr:ptr+K]))
    ptr += K
    adj = [[] for _ in range(N+1)]
    for _ in range(N-1):
        a, b = map(int, input[ptr:ptr+2])
        ptr += 2
        adj[a].append(b)
        adj[b].append(a)
    
    root = 1
    parent = [0] * (N + 1)
    depth = [0] * (N + 1)
    visited = [False] * (N + 1)
    q = deque([root])
    visited[root] = True
    while q:
        u = q.popleft()
        for v in adj[u]:
            if not visited[v] and v != parent[u]:
                parent[v] = u
                depth[v] = depth[u] + 1
                visited[v] = True
                q.append(v)
    
    children = [[] for _ in range(N + 1)]
    for v in range(2, N + 1):
        p = parent[v]
        children[p].append(v)
    
    subtree = [{} for _ in range(N + 1)]
    sum_children = [0] * (N + 1)
    for p in range(1, N + 1):
        for c in children[p]:
            subtree[p][c] = 0
    
    cnt = [0] * (N + 1)
    sum_depth = 0
    slime_pos = []
    for c in C:
        slime_pos.append(c)
        cnt[c] += 1
        sum_depth += depth[c]
        current = c
        while current != root:
            p = parent[current]
            subtree[p][current] += 1
            sum_children[p] += 1
            current = p
    
    for _ in range(Q):
        query = input[ptr]
        ptr += 1
        if query == '1':
            p = int(input[ptr]) - 1
            ptr += 1
            d = int(input[ptr])
            ptr += 1
            old_d = slime_pos[p]
            new_d = d
            slime_pos[p] = new_d
            
            sum_depth -= depth[old_d]
            cnt[old_d] -= 1
            current = old_d
            while current != root:
                p_node = parent[current]
                subtree[p_node][current] -= 1
                sum_children[p_node] -= 1
                current = p_node
            
            sum_depth += depth[new_d]
            cnt[new_d] += 1
            current = new_d
            while current != root:
                p_node = parent[current]
                subtree[p_node][current] += 1
                sum_children[p_node] += 1
                current = p_node
        else:
            e = int(input[ptr])
            ptr += 1
            path = []
            current = e
            while current != root:
                path.append(current)
                current = parent[current]
            path.append(root)
            path = path[::-1]
            
            sum_lca_depth = 0
            for i in range(len(path)):
                a = path[i]
                if i == len(path) - 1:
                    sum_sub = cnt[a] + sum_children[a]
                    sum_lca_depth += depth[a] * sum_sub
                else:
                    next_node = path[i + 1]
                    if parent[next_node] == a:
                        c = next_node
                        sum_sub = cnt[a] + sum_children[a] - subtree[a].get(c, 0)
                    else:
                        sum_sub = cnt[a] + sum_children[a]
                    sum_lca_depth += depth[a] * sum_sub
            
            total = sum_depth + depth[e] * K - 2 * sum_lca_depth
            print(total)

if __name__ == '__main__':
    main()
0