結果
問題 |
No.3272 Separate Contractions
|
ユーザー |
![]() |
提出日時 | 2025-08-10 17:53:59 |
言語 | PyPy3 (7.3.15) |
結果 |
AC
|
実行時間 | 1,660 ms / 3,000 ms |
コード長 | 3,203 bytes |
コンパイル時間 | 415 ms |
コンパイル使用メモリ | 82,040 KB |
実行使用メモリ | 341,612 KB |
最終ジャッジ日時 | 2025-09-11 00:35:36 |
合計ジャッジ時間 | 33,171 ms |
ジャッジサーバーID (参考情報) |
judge3 / judge |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | AC * 3 |
other | AC * 43 |
ソースコード
# Solution script for yukicoder No.3272 "Separate Contractions".
#
# Input (stdin): N, followed by N-1 lines "u v" (1-indexed tree edges).
# Output (stdout): one integer per edge, in input order.
#
# NOTE(review): on small hand-checked trees the value printed for edge i equals
# the sum of vertex eccentricities after contracting edge i -- confirm against
# the problem statement.
#
# All traversals are iterative, so no recursion-limit or JIT tuning is required
# and the module imports cleanly on any Python 3 interpreter.
import sys
from collections import deque


def get_dist(N, G, start):
    """Return BFS distances (in edges) from ``start`` to every vertex.

    Args:
        N: number of vertices, labelled ``0 .. N-1``.
        G: adjacency lists; ``G[x]`` holds the neighbours of ``x``.
        start: source vertex.

    Returns:
        A list of length ``N``; entries stay ``-1`` for vertices the search
        never reaches.
    """
    dist = [-1] * N
    dist[start] = 0
    queue = deque([start])
    while queue:
        x = queue.popleft()
        next_dist = dist[x] + 1
        for y in G[x]:
            if dist[y] == -1:
                dist[y] = next_dist
                queue.append(y)
    return dist


def get_center(N, G):
    """Find the diameter length and the centre vertex/vertices of the tree.

    Two BFS passes are used: the first from vertex 0 to find one end of a
    longest path, the second from that end to find the other end. Ties are
    resolved in favour of the smallest vertex index.

    Args:
        N: number of vertices.
        G: adjacency lists of a tree (assumed connected).

    Returns:
        ``(diameter, center1, center2)``. ``center2`` is ``-1`` when the
        diameter is even (single centre); otherwise ``center1`` and
        ``center2`` are the two middle vertices of the longest path found.
    """
    from_zero = get_dist(N, G, 0)
    # max() returns the first maximal item, i.e. the smallest index on ties.
    end_a = max(range(N), key=from_zero.__getitem__)
    dist = get_dist(N, G, end_a)
    end_b = max(range(N), key=dist.__getitem__)
    diameter = dist[end_b]

    # Walk from end_b back to end_a. In a tree exactly one neighbour is one
    # step closer to the BFS source, so the path is unambiguous.
    path = [end_b]
    cur = end_b
    while dist[cur] > 0:
        closer = dist[cur] - 1
        cur = next(nb for nb in G[cur] if dist[nb] == closer)
        path.append(cur)

    mid = diameter // 2
    if diameter % 2:
        return diameter, path[mid], path[mid + 1]
    return diameter, path[mid], -1


def _label_subtree(G, root, blocked, root_depth, half, dep, sz, cnt, top, parent):
    """Fill per-vertex tables for the subtree hanging at ``root``.

    The traversal starts at ``root`` and never crosses back into ``blocked``
    (the neighbour on the centre side). For every visited vertex ``x``:

    * ``dep[x]``: ``root_depth`` plus the distance from ``root``;
    * ``top[x]``: ``root``;
    * ``sz[x]``: number of vertices in the subtree of ``x``;
    * ``cnt[x]``: number of vertices in that subtree with ``dep == half``.

    ``parent`` is scratch storage shared between calls. Nothing is
    accumulated into ``blocked``.
    """
    parent[root] = blocked
    dep[root] = root_depth
    order = []
    stack = [root]
    while stack:
        x = stack.pop()
        order.append(x)
        top[x] = root
        sz[x] = 1
        if dep[x] == half:
            cnt[x] = 1
        px = parent[x]
        child_depth = dep[x] + 1
        for y in G[x]:
            if y != px:
                parent[y] = x
                dep[y] = child_depth
                stack.append(y)
    # Children appear after their parent in `order`, so a reverse sweep folds
    # each finished subtree into its parent. Index 0 (the root) is skipped so
    # the vertex on the far side of `blocked` is left untouched.
    for idx in range(len(order) - 1, 0, -1):
        x = order[idx]
        px = parent[x]
        sz[px] += sz[x]
        cnt[px] += cnt[x]


def solve(N, E, G):
    """Compute one answer per edge of the tree.

    Args:
        N: number of vertices.
        E: list of ``(u, v)`` edges (0-indexed), in input order.
        G: adjacency lists matching ``E``.

    Returns:
        A list of ``N - 1`` integers; entry ``i`` belongs to ``E[i]``.
        NOTE(review): hand-checked small cases match "sum of eccentricities
        after contracting edge i" -- confirm against the problem statement.
    """
    diameter, center1, center2 = get_center(N, G)
    half = diameter // 2
    dep = [0] * N
    sz = [0] * N
    cnt = [0] * N
    top = [0] * N
    parent = [-1] * N
    ans = [0] * (N - 1)

    if diameter % 2 == 0:
        # Single centre: every neighbour of the centre roots its own branch
        # at depth 1. The centre itself keeps dep = sz = cnt = 0.
        deep_branches = 0  # branches that contain a vertex at depth `half`
        deep_total = 0     # total size of those branches
        for r in G[center1]:
            _label_subtree(G, r, center1, 1, half, dep, sz, cnt, top, parent)
            if cnt[r] > 0:
                deep_branches += 1
                deep_total += sz[r]
        base = sum(dep) + half * N
        for i, (u, v) in enumerate(E):
            if dep[u] > dep[v]:
                u, v = v, u  # make u the endpoint nearer the centre
            res = base - (dep[u] + half) - sz[v]
            branch = top[v]
            # Extra reduction only when exactly two branches reach depth
            # `half` and every such vertex of this branch lies below v.
            if deep_branches == 2 and cnt[branch] > 0 and cnt[branch] == cnt[v]:
                res -= deep_total - sz[branch]
            ans[i] = res
    else:
        # Two centres: each one roots its own side at depth 0.
        _label_subtree(G, center1, center2, 0, half, dep, sz, cnt, top, parent)
        _label_subtree(G, center2, center1, 0, half, dep, sz, cnt, top, parent)
        base = sum(dep) + (half + 1) * N
        for i, (u, v) in enumerate(E):
            if (u, v) == (center1, center2) or (u, v) == (center2, center1):
                # The edge joining the two centres is handled separately.
                ans[i] = base - (half + N)
                continue
            if dep[u] > dep[v]:
                u, v = v, u  # make u the endpoint nearer its centre
            res = base - (dep[u] + half + 1 + sz[v])
            side = top[v]
            # Every depth-`half` vertex of this side lies below v.
            if cnt[side] == cnt[v]:
                res -= N - sz[side]
            ans[i] = res
    return ans


def main():
    """Read the tree from stdin, solve it, and print one answer per line."""
    data = sys.stdin.buffer.read().split()
    N = int(data[0])
    E = []
    G = [[] for _ in range(N)]
    for i in range(N - 1):
        u = int(data[2 * i + 1]) - 1
        v = int(data[2 * i + 2]) - 1
        E.append((u, v))
        G[u].append(v)
        G[v].append(u)
    ans = solve(N, E, G)
    if ans:  # print nothing at all when there are no edges
        sys.stdout.write("\n".join(map(str, ans)) + "\n")


if __name__ == "__main__":
    main()