結果

問題 No.1215 都市消滅ビーム
ユーザー gew1fw
提出日時 2025-06-12 13:53:47
言語 PyPy3
(7.3.15)
結果
MLE  
実行時間 -
コード長 3,033 bytes
コンパイル時間 291 ms
コンパイル使用メモリ 82,780 KB
実行使用メモリ 575,304 KB
最終ジャッジ日時 2025-06-12 13:55:41
合計ジャッジ時間 9,351 ms
ジャッジサーバーID
(参考情報)
judge5 / judge2
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample -- * 2
other AC * 13 MLE * 1 -- * 26
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
from sys import stdin
# Raise the interpreter's recursion ceiling to 2**25 frames.  Nothing in this
# file recurses (tree traversal and LCA are iterative), so this is harmless
# boilerplate kept for parity with the original template.
sys.setrecursionlimit(2 ** 25)

def _bfs_parent_depth(n, adj):
    """Return ``(parent, depth)`` for the tree *adj* rooted at vertex 1.

    Vertices are ``1..n``; index 0 is a sentinel whose parent is 0, so
    ``parent[1] == 0`` and jumping above the root stays at 0.
    """
    from collections import deque

    parent = [0] * (n + 1)
    depth = [0] * (n + 1)
    seen = [False] * (n + 1)
    seen[1] = True
    queue = deque([1])
    while queue:
        u = queue.popleft()
        for v in adj[u]:
            if not seen[v]:
                seen[v] = True
                parent[v] = u
                depth[v] = depth[u] + 1
                queue.append(v)
    return parent, depth


def _make_lca(parent, depth, n):
    """Build a binary-lifting table and return an ``lca(u, v)`` function.

    The number of levels is derived from *n* (a tree on n vertices has depth
    < n) rather than hard-coded, so any tree size is supported.
    """
    levels = max(1, n.bit_length())
    up = [parent]
    for _ in range(levels - 1):
        prev = up[-1]
        up.append([prev[prev[v]] for v in range(n + 1)])

    def lca(u, v):
        if depth[u] < depth[v]:
            u, v = v, u
        # Lift u by exactly the depth difference, one set bit at a time.
        diff = depth[u] - depth[v]
        level = 0
        while diff:
            if diff & 1:
                u = up[level][u]
            diff >>= 1
            level += 1
        if u == v:
            return u
        for level in range(levels - 1, -1, -1):
            row = up[level]
            if row[u] != row[v]:
                u = row[u]
                v = row[v]
        return parent[u]

    return lca


def _fenwick_add(tree, x):
    """Add 1 at 1-based position *x* of the Fenwick tree *tree*."""
    size = len(tree) - 1
    while x <= size:
        tree[x] += 1
        x += x & -x


def _fenwick_prefix(tree, x):
    """Return the sum of positions ``1..x`` of the Fenwick tree *tree*."""
    total = 0
    while x > 0:
        total += tree[x]
        x &= x - 1
    return total


def _count_range_sums(v, queries, vals, pos, size):
    """Count pairs (query, element) with ``lo <= pos < hi`` and ``base + val <= v``.

    *queries* holds ``(base, lo, hi)`` triples sorted by ``base`` descending,
    so the threshold ``v - base`` never decreases and elements can be inserted
    with a single forward pointer.  *vals* lists element values in ascending
    order and *pos* their 1-based positions (each ``<= size``).
    """
    tree = [0] * (size + 1)
    n_elems = len(vals)
    p = 0
    count = 0
    for base, lo, hi in queries:
        limit = v - base
        while p < n_elems and vals[p] <= limit:
            _fenwick_add(tree, pos[p])
            p += 1
        count += _fenwick_prefix(tree, hi - 1) - _fenwick_prefix(tree, lo - 1)
    return count


def _lower_median_score(n, k, c, d, adj):
    """Return ``sorted(scores)[(m - 1) // 2]`` without materialising the scores.

    Index the deletions by ``i`` (number of kept leading entries) and ``j``
    (0-based index of the first kept trailing entry), so ``c[i:j]`` is
    deleted, ``i < j``.  With ``pre``/``suf`` the prefix/suffix sums of *d*,
    ``da[i]`` the depth of the LCA of ``c[:i]`` and ``db[j]`` the depth of
    the LCA of ``c[j:]``, the score is:

    * nothing deleted:    ``pre[k] + da[k]``
    * everything deleted: ``-10**10``
    * ``i == 0, j < k``:  ``suf[j] + db[j]``
    * ``0 < i, j == k``:  ``pre[i] + da[i]``
    * ``0 < i < j < k``:  ``pre[i] + suf[j] + min(da[i], db[j], dw)``

    where ``dw`` is the depth of ``lca(c[0], c[k-1])``.  The last formula
    holds because ``lca(c[:i])`` lies on the root path of ``c[0]`` and
    ``lca(c[j:])`` on the root path of ``c[k-1]``, and those two paths share
    exactly the vertices of depth ``<= dw``.

    The first four kinds (2k values) are listed explicitly.  The O(k^2)
    pairs are never stored: the answer is binary-searched and the pairs with
    score ``<= v`` are counted in O(k log k) per step.
    """
    from bisect import bisect_right

    sentinel = -10**10
    if k == 0:
        return sentinel

    parent, depth = _bfs_parent_depth(n, adj)
    lca = _make_lca(parent, depth, n)

    pre = [0] * (k + 1)
    for i, value in enumerate(d):
        pre[i + 1] = pre[i] + value
    total = pre[k]
    suf = [total - p for p in pre]  # suf[j] == sum(d[j:])

    da = [0] * (k + 1)  # da[i]: depth of LCA of c[:i], for 1 <= i <= k
    node = c[0]
    da[1] = depth[node]
    for i in range(1, k):
        node = lca(node, c[i])
        da[i + 1] = depth[node]

    db = [0] * (k + 1)  # db[j]: depth of LCA of c[j:], for 0 <= j < k
    node = c[k - 1]
    db[k - 1] = depth[node]
    for j in range(k - 2, -1, -1):
        node = lca(c[j], node)
        db[j] = depth[node]

    dw = depth[lca(c[0], c[k - 1])]

    small = [total + da[k], sentinel]
    small.extend(suf[j] + db[j] for j in range(1, k))
    small.extend(pre[i] + da[i] for i in range(1, k))
    small.sort()

    # eb[j - 1] = min(db[j], dw) for j in 1..k-1; non-decreasing in j because
    # c[j:] shrinks as j grows, so its LCA can only move deeper.
    eb = [min(db[j], dw) for j in range(1, k)]
    a_queries = []  # score = pre[i] + (suf[j] + eb_j) for j in [i+1, t)
    b_queries = []  # score = (pre[i] + ea_i) + suf[j] for j in [max(i+1, t), k)
    for i in range(1, k - 1):
        ea = min(da[i], dw)
        t = 1 + bisect_right(eb, ea)  # first j with eb_j > ea (k if none)
        if t > i + 1:
            a_queries.append((pre[i], i + 1, t))
        start = max(i + 1, t)
        if start < k:
            b_queries.append((pre[i] + ea, start, k))
    a_queries.sort(reverse=True)
    b_queries.sort(reverse=True)

    f_order = sorted(range(1, k), key=lambda j: suf[j] + eb[j - 1])
    f_vals = [suf[j] + eb[j - 1] for j in f_order]
    s_order = sorted(range(1, k), key=suf.__getitem__)
    s_vals = [suf[j] for j in s_order]
    size = k - 1

    m = k * (k + 1) // 2 + 1
    need = (m - 1) // 2 + 1  # answer = smallest v with need scores <= v
    # Every pair score lies in [min(pre)+min(suf), max(pre)+max(suf)+dw].
    lo = min(small[0], min(pre) + min(suf))
    hi = max(small[-1], max(pre) + max(suf) + dw)
    while lo < hi:
        mid = (lo + hi) // 2
        count = (
            bisect_right(small, mid)
            + _count_range_sums(mid, a_queries, f_vals, f_order, size)
            + _count_range_sums(mid, b_queries, s_vals, s_order, size)
        )
        if count >= need:
            hi = mid
        else:
            lo = mid + 1
    return lo


def main():
    """Read one instance from stdin and print the lower-median score.

    Input tokens: ``N K``, then K values ``C``, K values ``D``, then ``N-1``
    undirected edges ``a b`` of a tree on vertices ``1..N`` rooted at 1.

    The printed value equals materialising every score (one per deleted
    non-empty contiguous range of entries, plus "delete nothing") and taking
    ``sorted(scores)[(m - 1) // 2]``.  The earlier version did exactly that,
    holding ``K*(K+1)/2 + 1`` integers (O(K^2) memory, the MLE in the judge
    log).  This version keeps O(N log N + K) memory and runs in
    O(N log N + K log K * log V), V being the span of possible scores.
    """
    tokens = sys.stdin.read().split()
    n, k = int(tokens[0]), int(tokens[1])
    c = [int(t) for t in tokens[2:2 + k]]
    d = [int(t) for t in tokens[2 + k:2 + 2 * k]]
    adj = [[] for _ in range(n + 1)]
    pos = 2 + 2 * k
    for _ in range(n - 1):
        a, b = int(tokens[pos]), int(tokens[pos + 1])
        pos += 2
        adj[a].append(b)
        adj[b].append(a)
    print(_lower_median_score(n, k, c, d, adj))

if __name__ == "__main__":
    # Solve the instance on stdin only when run as a script, not on import.
    main()