結果
問題 |
No.1215 都市消滅ビーム
|
ユーザー |
![]() |
提出日時 | 2025-06-12 13:49:25 |
言語 | PyPy3 (7.3.15) |
結果 |
WA
|
実行時間 | - |
コード長 | 3,529 bytes |
コンパイル時間 | 418 ms |
コンパイル使用メモリ | 82,744 KB |
実行使用メモリ | 136,476 KB |
最終ジャッジ日時 | 2025-06-12 13:49:44 |
合計ジャッジ時間 | 9,756 ms |
ジャッジサーバーID (参考情報) |
judge3 / judge1 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | AC * 2 |
other | AC * 8 WA * 32 |
ソースコード
import sys
from collections import deque

sys.setrecursionlimit(1 << 25)

# Score assigned to a candidate in which every (city, value) pair was deleted.
NEG_INF = -10**18
# Binary-lifting levels: jumps of 2**0 .. 2**19 cover any depth below 2**20.
LOG = 20


def _build_lca(n, adj):
    """Root the tree at vertex 1 and prepare binary-lifting LCA queries.

    Args:
        n: number of vertices (labelled 1..n).
        adj: adjacency lists indexed by vertex; index 0 is unused.

    Returns:
        (depth, lca): ``depth[v]`` is the edge distance from vertex 1, and
        ``lca(u, v)`` returns the lowest common ancestor of ``u`` and ``v``.
    """
    depth = [0] * (n + 1)
    # up[j][v] is the 2**j-th ancestor of v, or -1 above the root.
    up = [[-1] * (n + 1) for _ in range(LOG)]
    seen = [False] * (n + 1)
    seen[1] = True
    queue = deque([1])
    while queue:
        u = queue.popleft()
        for v in adj[u]:
            if not seen[v]:
                seen[v] = True
                up[0][v] = u
                depth[v] = depth[u] + 1
                queue.append(v)
    for j in range(1, LOG):
        half, full = up[j - 1], up[j]
        for v in range(1, n + 1):
            mid = half[v]
            # Guard: index -1 would silently read the last list element.
            full[v] = -1 if mid == -1 else half[mid]

    def lca(u, v):
        if depth[u] < depth[v]:
            u, v = v, u
        # Lift the deeper vertex up to the other's depth.
        for j in range(LOG - 1, -1, -1):
            if depth[u] - (1 << j) >= depth[v]:
                u = up[j][u]
        if u == v:
            return u
        # Lift both while their ancestors differ; they stop just below the LCA.
        for j in range(LOG - 1, -1, -1):
            pu, pv = up[j][u], up[j][v]
            if pu != -1 and pu != pv:
                u, v = pu, pv
        return up[0][u]

    return depth, lca


def _candidate_scores(cities, values, depth, lca):
    """Return the (unsorted) scores of all candidate deletions.

    A remaining, non-empty set S of pair indices scores
    ``sum(values[j] for j in S) + depth[LCA(cities[j] for j in S)]``;
    an empty remaining set scores ``NEG_INF``.

    Candidates: delete nothing, delete everything, and delete exactly one
    pair i for each i in 1..K.

    NOTE(review): the judge reported WA on 32 of 40 tests for this candidate
    set.  The prefix/suffix tables below would equally support deleting any
    contiguous range [l, r] (score(l - 1, r + 1)); if the problem asks about
    every such range, the candidate set must be widened -- confirm against
    the problem statement.
    """
    k = len(cities)
    # pre_lca[i] / pre_sum[i]: LCA and value sum of pairs 1..i (None / 0 at i == 0).
    pre_lca = [None] * (k + 1)
    pre_sum = [0] * (k + 1)
    for i in range(1, k + 1):
        city = cities[i - 1]
        pre_lca[i] = city if i == 1 else lca(pre_lca[i - 1], city)
        pre_sum[i] = pre_sum[i - 1] + values[i - 1]
    # suf_lca[i] / suf_sum[i]: same for pairs i..k (None / 0 at i == k + 1).
    suf_lca = [None] * (k + 2)
    suf_sum = [0] * (k + 2)
    for i in range(k, 0, -1):
        city = cities[i - 1]
        suf_lca[i] = city if i == k else lca(city, suf_lca[i + 1])
        suf_sum[i] = suf_sum[i + 1] + values[i - 1]

    def score(left, right):
        """Score of keeping pairs 1..left and right..k (either part may be empty)."""
        a, b = pre_lca[left], suf_lca[right]
        if a is None and b is None:
            # Nothing remains: return the sentinel before any LCA is needed.
            return NEG_INF
        if a is None:
            top = b
        elif b is None:
            top = a
        else:
            top = lca(a, b)
        return pre_sum[left] + suf_sum[right] + depth[top]

    scores = [score(k, k + 1), NEG_INF]
    scores.extend(score(i - 1, i + 1) for i in range(1, k + 1))
    return scores


def _lower_median(scores):
    """Return the lower median: the ((m + 1) // 2)-th smallest of m scores."""
    ordered = sorted(scores)
    return ordered[(len(ordered) + 1) // 2 - 1]


def main():
    """Read the instance from stdin and print the lower median candidate score.

    Input tokens, in order: N K, then C_1..C_K, then D_1..D_K, then N-1
    edges given as vertex pairs "a b".
    """
    tokens = iter(sys.stdin.read().split())
    n = int(next(tokens))
    k = int(next(tokens))
    cities = [int(next(tokens)) for _ in range(k)]
    values = [int(next(tokens)) for _ in range(k)]
    adj = [[] for _ in range(n + 1)]
    for _ in range(n - 1):
        a = int(next(tokens))
        b = int(next(tokens))
        adj[a].append(b)
        adj[b].append(a)
    depth, lca = _build_lca(n, adj)
    print(_lower_median(_candidate_scores(cities, values, depth, lca)))


if __name__ == '__main__':
    main()