結果

問題 No.399 動的な領主
ユーザー lam6er
提出日時 2025-03-20 20:56:36
言語 PyPy3
(7.3.15)
結果
AC  
実行時間 644 ms / 2,000 ms
コード長 2,433 bytes
コンパイル時間 216 ms
コンパイル使用メモリ 82,144 KB
実行使用メモリ 147,600 KB
最終ジャッジ日時 2025-03-20 20:57:00
合計ジャッジ時間 7,889 ms
ジャッジサーバーID
(参考情報)
judge2 / judge1
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
other AC * 19
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
from collections import deque

def main():
    sys.setrecursionlimit(1 << 25)
    input = sys.stdin.read().split()
    ptr = 0

    N = int(input[ptr])
    ptr += 1

    adj = [[] for _ in range(N + 1)]
    for _ in range(N - 1):
        u = int(input[ptr])
        v = int(input[ptr + 1])
        adj[u].append(v)
        adj[v].append(u)
        ptr += 2

    # Initialize parent, children, and depth arrays
    parent = [0] * (N + 1)
    children = [[] for _ in range(N + 1)]
    depth = [0] * (N + 1)
    root = 1
    parent[root] = 0
    depth[root] = 0
    q = deque([root])

    while q:
        u = q.popleft()
        for v in adj[u]:
            if v != parent[u]:
                parent[v] = u
                depth[v] = depth[u] + 1
                children[u].append(v)
                q.append(v)

    max_level = 20
    jump = [[0] * max_level for _ in range(N + 1)]
    for u in range(1, N + 1):
        jump[u][0] = parent[u]

    for k in range(1, max_level):
        for u in range(1, N + 1):
            jump[u][k] = jump[jump[u][k-1]][k-1]

    def get_lca(u, v):
        if depth[u] < depth[v]:
            u, v = v, u
        for k in range(max_level - 1, -1, -1):
            if depth[u] - (1 << k) >= depth[v]:
                u = jump[u][k]
        if u == v:
            return u
        for k in range(max_level - 1, -1, -1):
            if jump[u][k] != jump[v][k]:
                u = jump[u][k]
                v = jump[v][k]
        return jump[u][0]

    diff = [0] * (N + 1)
    Q = int(input[ptr])
    ptr += 1
    for _ in range(Q):
        a = int(input[ptr])
        b = int(input[ptr + 1])
        ptr += 2
        lca_node = get_lca(a, b)
        diff[a] += 1
        diff[b] += 1
        diff[lca_node] -= 1
        if parent[lca_node] != 0:
            diff[parent[lca_node]] -= 1

    # Post-order traversal using stack
    k_x = [0] * (N + 1)
    stack = [(root, False)]
    while stack:
        node, visited = stack.pop()
        if not visited:
            stack.append((node, True))
            for child in reversed(children[node]):
                stack.append((child, False))
        else:
            total = 0
            for child in children[node]:
                total += k_x[child]
            k_x[node] = total + diff[node]

    ans = 0
    for i in range(1, N + 1):
        ans += k_x[i] * (k_x[i] + 1) // 2

    print(ans)

if __name__ == '__main__':
    main()
0