# yukicoder submission record (scraped page header): problem No.1215, PyPy3, verdict WA.
# 結果 (Result)
#
# 問題 No.1215 都市消滅ビーム
# ユーザー gew1fw
# 提出日時 2025-06-12 18:50:52
# 言語 PyPy3 (7.3.15)
# 結果 WA
# 実行時間 -
# コード長 3,529 bytes
# コンパイル時間 141 ms
# コンパイル使用メモリ 82,788 KB
# 実行使用メモリ 136,460 KB
# 最終ジャッジ日時 2025-06-12 18:51:03
# 合計ジャッジ時間 8,355 ms
# ジャッジサーバーID (参考情報) judge3 / judge1
# このコードへのチャレンジ (要ログイン)
# ファイルパターン 結果
#   sample  AC * 2
#   other   AC * 8  WA * 32
# 権限があれば一括ダウンロードができます

# ソースコード (source code)
#
# diff #

import sys
# Raise the interpreter recursion limit.
# NOTE(review): nothing visible in this file recurses (the tree walk is an
# iterative BFS and the LCA query is loop-based), so this looks purely defensive.
sys.setrecursionlimit(1 << 25)

_NEG_INF = -10**18  # sentinel score used when no index is kept (stand-in for "minus infinity")


def _build_lca(n, edges, root=1, log=20):
    """Root the tree at ``root`` and return ``(depth, lca)``.

    ``edges`` is a 1-indexed adjacency list (index 0 unused). ``depth[v]`` is
    the number of edges from ``root`` to ``v`` (``depth[root] == 0``).
    ``lca(u, v)`` answers lowest-common-ancestor queries with a binary-lifting
    table of ``log`` levels, so depths are expected to stay below ``2**log``.
    """
    from collections import deque

    depth = [0] * (n + 1)
    parent = [-1] * (n + 1)
    seen = [False] * (n + 1)
    seen[root] = True
    queue = deque([root])
    while queue:  # iterative BFS: no recursion-depth concerns on long paths
        u = queue.popleft()
        for v in edges[u]:
            if not seen[v]:
                seen[v] = True
                parent[v] = u
                depth[v] = depth[u] + 1
                queue.append(v)

    # up[j][v] is the 2**j-th ancestor of v, or -1 when it does not exist.
    up = [parent]
    for _ in range(1, log):
        half = up[-1]
        up.append([-1 if p == -1 else half[p] for p in half])

    def lca(u, v):
        if depth[u] < depth[v]:
            u, v = v, u
        # Lift the deeper vertex to the depth of the shallower one.
        for level in range(log - 1, -1, -1):
            if depth[u] - (1 << level) >= depth[v]:
                u = up[level][u]
        if u == v:
            return u
        # Raise both while their ancestors differ; they stop just below the LCA.
        for level in range(log - 1, -1, -1):
            pu, pv = up[level][u], up[level][v]
            if pu != -1 and pu != pv:
                u, v = pu, pv
        return up[0][u]

    return depth, lca


def _running_lca_and_sum(cities, powers, lca):
    """Return ``(lcas, sums)`` where index ``j`` describes the first ``j`` items.

    ``lcas[j]`` is the LCA of ``cities[:j]`` (``None`` for ``j == 0``) and
    ``sums[j]`` is ``sum(powers[:j])``.
    """
    lcas = [None]
    sums = [0]
    for c, d in zip(cities, powers):
        lcas.append(c if lcas[-1] is None else lca(lcas[-1], c))
        sums.append(sums[-1] + d)
    return lcas, sums


def main():
    """Read one test case from stdin and print the median candidate score.

    Input (whitespace-separated tokens): ``N K``, then ``C_1..C_K``, then
    ``D_1..D_K``, then ``N-1`` edges ``a b``. The tree is rooted at vertex 1.

    A kept set of indices scores ``sum of its D`` plus ``depth(LCA of its C)``.
    An empty kept set scores the sentinel ``_NEG_INF``. The candidates scored
    here are:

    - keep every index;
    - delete every index;
    - delete exactly one index ``i``, keeping the prefix ``1..i-1`` and the
      suffix ``i+1..K``.

    The printed value is the ``ceil(m/2)``-th smallest of those ``m`` scores.

    NOTE(review): the judge reported WA on 32 tests for this approach. If the
    task is to delete an arbitrary contiguous range of indices, only
    single-index deletions are enumerated here. TODO: confirm against the
    problem statement.
    """
    import sys

    data = sys.stdin.read().split()
    n, k = int(data[0]), int(data[1])
    cities = [int(tok) for tok in data[2:2 + k]]
    powers = [int(tok) for tok in data[2 + k:2 + 2 * k]]

    edges = [[] for _ in range(n + 1)]
    pos = 2 + 2 * k
    for _ in range(n - 1):
        a, b = int(data[pos]), int(data[pos + 1])
        pos += 2
        edges[a].append(b)
        edges[b].append(a)

    depth, lca = _build_lca(n, edges)

    pre_lca, pre_sum = _running_lca_and_sum(cities, powers, lca)
    rev_lca, rev_sum = _running_lca_and_sum(cities[::-1], powers[::-1], lca)
    # suf_*[r] (1-indexed) describes indices r..K; r == K + 1 is the empty
    # suffix, and index 0 is an unused placeholder.
    suf_lca = [None] + rev_lca[::-1]
    suf_sum = [0] + rev_sum[::-1]

    def kept_score(left, right):
        """Score of keeping indices 1..left and right..K (1-indexed)."""
        a, b = pre_lca[left], suf_lca[right]
        if a is None and b is None:
            # Nothing kept (only reachable for K == 1). Previously this path
            # left the LCA variable unassigned and crashed.
            return _NEG_INF
        if a is None:
            top = b
        elif b is None:
            top = a
        else:
            top = lca(a, b)
        return pre_sum[left] + suf_sum[right] + depth[top]

    scores = [kept_score(k, k + 1), _NEG_INF]
    scores.extend(kept_score(i - 1, i + 1) for i in range(1, k + 1))
    scores.sort()
    print(scores[(len(scores) + 1) // 2 - 1])

if __name__ == "__main__":
    # Script entry point: importing this module does not run the solver.
    main()