結果
| 問題 | No.399 動的な領主 |
| コンテスト | |
| ユーザー | lam6er |
| 提出日時 | 2025-03-20 20:56:36 |
| 言語 | PyPy3 (7.3.15) |
| 結果 | AC |
| 実行時間 | 644 ms / 2,000 ms |
| コード長 | 2,433 bytes |
| コンパイル時間 | 216 ms |
| コンパイル使用メモリ | 82,144 KB |
| 実行使用メモリ | 147,600 KB |
| 最終ジャッジ日時 | 2025-03-20 20:57:00 |
| 合計ジャッジ時間 | 7,889 ms |
| ジャッジサーバーID (参考情報) | judge2 / judge1 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| other | AC * 19 |
ソースコード
import sys
from collections import deque
def _root_tree(adj, root):
    """Root the tree at *root* with a breadth-first search.

    Args:
        adj: adjacency lists indexed by vertex; index 0 is an unused sentinel.
        root: vertex to use as the root.

    Returns:
        A tuple ``(order, parent, depth)`` where ``order`` lists the vertices
        in BFS visiting order, ``parent[v]`` is the parent of ``v`` (0 for the
        root) and ``depth[v]`` is the number of edges from the root to ``v``.
    """
    size = len(adj)
    parent = [0] * size
    depth = [0] * size
    order = []
    queue = deque([root])
    while queue:
        u = queue.popleft()
        order.append(u)
        for v in adj[u]:
            # The input is a tree, so the only already-visited neighbour of u
            # is its parent.
            if v != parent[u]:
                parent[v] = u
                depth[v] = depth[u] + 1
                queue.append(v)
    return order, parent, depth


def _build_ancestor_table(parent):
    """Build a binary-lifting table from a parent array.

    ``table[k][v]`` is the 2**k-th ancestor of ``v``, or 0 (the sentinel) when
    that ancestor does not exist.  The number of levels is derived from the
    vertex count, so any depth that can occur in a tree of that size is
    reachable.
    """
    vertex_count = len(parent) - 1
    # 2**levels > vertex_count > maximum possible depth.
    levels = max(1, vertex_count.bit_length())
    table = [parent]
    for _ in range(1, levels):
        prev = table[-1]
        # Iterating prev yields prev[v] for v = 0, 1, ..., so prev[p] is
        # prev[prev[v]]; prev[0] == 0 keeps the sentinel fixed.
        table.append([prev[p] for p in prev])
    return table


def _lca(u, v, depth, table):
    """Return the lowest common ancestor of *u* and *v*."""
    if depth[u] < depth[v]:
        u, v = v, u
    # Lift the deeper vertex until both are at the same depth.
    gap = depth[u] - depth[v]
    level = 0
    while gap:
        if gap & 1:
            u = table[level][u]
        gap >>= 1
        level += 1
    if u == v:
        return u
    # Lift both vertices as far as possible while their ancestors differ.
    for row in reversed(table):
        if row[u] != row[v]:
            u = row[u]
            v = row[v]
    return table[0][u]


def main():
    """Solve the path-counting problem read from standard input.

    Input (whitespace separated): ``N``, then ``N - 1`` edges ``u v`` of a tree
    on vertices ``1..N``, then ``Q``, then ``Q`` vertex pairs ``a b``.

    For every vertex the number ``c`` of query paths ``a``-``b`` passing
    through it is computed, and the sum of ``c * (c + 1) // 2`` over all
    vertices is printed.
    """
    tokens = sys.stdin.read().split()
    pos = 0
    n = int(tokens[pos])
    pos += 1

    adj = [[] for _ in range(n + 1)]
    for _ in range(n - 1):
        u = int(tokens[pos])
        v = int(tokens[pos + 1])
        pos += 2
        adj[u].append(v)
        adj[v].append(u)

    order, parent, depth = _root_tree(adj, 1)
    table = _build_ancestor_table(parent)

    # Difference array on the tree: +1 at both endpoints, -1 at the LCA and at
    # the LCA's parent, so that subtree sums give the per-vertex path count.
    # Slot 0 (the root's "parent") is a sink that is never read.
    diff = [0] * (n + 1)
    query_count = int(tokens[pos])
    pos += 1
    for _ in range(query_count):
        a = int(tokens[pos])
        b = int(tokens[pos + 1])
        pos += 2
        top = _lca(a, b, depth, table)
        diff[a] += 1
        diff[b] += 1
        diff[top] -= 1
        diff[parent[top]] -= 1

    # In reverse BFS order every vertex is processed after all of its
    # descendants, so pushing totals to the parent yields subtree sums.
    for u in reversed(order):
        diff[parent[u]] += diff[u]

    answer = 0
    for vertex in range(1, n + 1):
        count = diff[vertex]
        answer += count * (count + 1) // 2
    print(answer)
# Run the solver only when this file is executed as a script.
if __name__ == "__main__":
    main()
lam6er