結果

問題 No.399 動的な領主
ユーザー lam6er
提出日時 2025-03-26 15:59:05
言語 PyPy3 (7.3.15)
結果 TLE
実行時間 -
コード長 5,830 bytes
コンパイル時間 169 ms
コンパイル使用メモリ 82,556 KB
実行使用メモリ 145,480 KB
最終ジャッジ日時 2025-03-26 15:59:51
合計ジャッジ時間 5,595 ms
ジャッジサーバーID (参考情報) judge2 / judge3
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
other AC * 6 TLE * 1 -- * 12
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
from sys import stdin
# Raise the interpreter recursion limit (a common competitive-programming
# preamble). NOTE(review): main() below does not recurse deeply, so this looks
# unnecessary; kept unchanged because it is a module-level side effect.
sys.setrecursionlimit(1 << 25)

def _root_tree(n, adj, root):
    """Orient an undirected tree at ``root`` without recursion.

    ``adj`` is a 1-indexed adjacency list (index 0 unused; vertices are
    assumed to be numbered 1..n). Returns ``(parent, depth, order)``:

    * ``parent[root] == 0``. Slot 0 is never a vertex, so it serves as a
      "no parent" sentinel and as a harmless write target.
    * ``depth[root] == 0``.
    * ``order`` is a DFS preorder starting at ``root``. Every vertex appears
      after its parent, so walking it backwards visits children before
      their parents.
    """
    parent = [0] * (n + 1)
    depth = [0] * (n + 1)
    order = []
    stack = [root]
    while stack:
        v = stack.pop()
        order.append(v)
        pv = parent[v]
        child_depth = depth[v] + 1
        for u in adj[v]:
            if u != pv:  # in a tree, every other neighbour is a child
                parent[u] = v
                depth[u] = child_depth
                stack.append(u)
    return parent, depth, order


def _heavy_path_heads(n, parent, order):
    """Return ``head`` for a heavy-light decomposition of the rooted tree.

    ``head[v]`` is the topmost vertex of the heavy path containing ``v``.
    The heavy child of each vertex is its child with the largest subtree,
    so any upward walk crosses only O(log n) heavy paths. That bound is
    what limits the LCA loop in ``main``.
    """
    size = [1] * (n + 1)
    heavy = [0] * (n + 1)  # 0 = no heavy child recorded yet
    # Reverse preorder: a vertex's subtree size is final before it is added
    # to its parent and compared against its siblings.
    for i in range(len(order) - 1, 0, -1):  # order[0] is the root
        v = order[i]
        p = parent[v]
        size[p] += size[v]
        best = heavy[p]
        if best == 0 or size[v] > size[best]:
            heavy[p] = v
    head = [0] * (n + 1)
    for v in order:  # preorder: head[parent] is set before head[child]
        p = parent[v]
        head[v] = head[p] if heavy[p] == v else v
    return head


def main():
    """Solve problem No.399 from stdin and print the total cost.

    Input: N, then N-1 edges ``u v``, then Q, then Q queries ``A B``.

    Each query walks the path A..B. A vertex visited for the k-th time costs
    k, so a vertex on c query paths contributes c*(c+1)/2 in total. Only the
    final visit counts matter, which makes the problem fully offline:

    * For each query, do an O(1) tree-difference update: +1 at A and at B,
      and -1 at LCA(A, B) and at its parent.
    * After all queries, one bottom-up accumulation turns the differences
      into per-vertex path counts.

    This replaces the previous per-query HLD + recursive lazy segment tree
    (O(Q log^2 N) with heavy Python call overhead, which timed out). The new
    cost is O(N + Q log N) with no recursion.
    """
    vals = map(int, sys.stdin.buffer.read().split())
    n = next(vals)

    adj = [[] for _ in range(n + 1)]
    for _ in range(n - 1):
        u = next(vals)
        v = next(vals)
        adj[u].append(v)
        adj[v].append(u)

    root = 1
    parent, depth, order = _root_tree(n, adj, root)
    head = _heavy_path_heads(n, parent, order)

    # After accumulation, diff[v] is the number of query paths containing v.
    diff = [0] * (n + 1)
    q = next(vals, 0)  # tolerate a missing Q line as "no queries"
    for _ in range(q):
        a = next(vals)
        b = next(vals)
        diff[a] += 1
        diff[b] += 1
        # LCA by heavy-path jumps: always lift the side whose path head is
        # deeper until both vertices are on the same heavy path.
        x, y = a, b
        while head[x] != head[y]:
            if depth[head[x]] > depth[head[y]]:
                x = parent[head[x]]
            else:
                y = parent[head[y]]
        c = x if depth[x] < depth[y] else y
        diff[c] -= 1
        diff[parent[c]] -= 1  # parent[root] == 0 is an ignored sink

    # Children before parents: push each subtree total up to its parent.
    for i in range(len(order) - 1, 0, -1):
        v = order[i]
        diff[parent[v]] += diff[v]

    print(sum(k * (k + 1) // 2 for k in diff[1:]))


if __name__ == '__main__':
    main()
0