結果

問題 No.399 動的な領主
ユーザー kept1994kept1994
提出日時 2023-06-17 16:17:19
言語 PyPy3
(7.3.15)
結果
AC  
実行時間 750 ms / 2,000 ms
コード長 4,146 bytes
コンパイル時間 327 ms
コンパイル使用メモリ 81,920 KB
実行使用メモリ 191,104 KB
最終ジャッジ日時 2024-06-25 06:39:19
合計ジャッジ時間 10,022 ms
ジャッジサーバーID
(参考情報)
judge4 / judge2
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 40 ms
52,736 KB
testcase_01 AC 39 ms
52,736 KB
testcase_02 AC 43 ms
53,632 KB
testcase_03 AC 42 ms
53,760 KB
testcase_04 AC 99 ms
76,672 KB
testcase_05 AC 243 ms
82,304 KB
testcase_06 AC 725 ms
107,924 KB
testcase_07 AC 693 ms
107,264 KB
testcase_08 AC 672 ms
106,112 KB
testcase_09 AC 681 ms
107,008 KB
testcase_10 AC 127 ms
77,676 KB
testcase_11 AC 211 ms
80,640 KB
testcase_12 AC 581 ms
105,904 KB
testcase_13 AC 576 ms
105,600 KB
testcase_14 AC 447 ms
190,336 KB
testcase_15 AC 458 ms
191,104 KB
testcase_16 AC 484 ms
153,216 KB
testcase_17 AC 750 ms
107,520 KB
testcase_18 AC 689 ms
106,752 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

#!/usr/bin/env python3
import sys
sys.setrecursionlimit(10 ** 9)

class LcaDoubling:
    # 木であれば任意の点を根と見做せる。
    def __init__(self, N, root=0):
        self.N = N
        self.root = root
        self.G = [[] for _ in range(N)]
        self.depths = [-1] * N
        self.distances = [-1] * N
        self.ancestors = [] # ダブリングによって求めた祖先の配列の配列 i番目の配列は過去ノードの2^i個祖先のノードを格納する。
        return
    
    def addEdge(self, fromNode: int, toNode: int, cost: int):
        self.G[fromNode].append((cost, toNode))
        return
    
    def build(self):
        """
        O(NlogN)
        """
        prevAncestors = self._bfs()
        self.ancestors.append(prevAncestors)
        d = 1
        max_depth = max(self.depths)
        while d < max_depth:
            nextAncestors = [prevAncestors[p] for p in prevAncestors]
            self.ancestors.append(nextAncestors)
            d <<= 1
            prevAncestors = nextAncestors
        return

    def _bfs(self):
        q = [(self.root, -1, 0, 0)]
        directAncestors = [-1] * (self.N + 1)  # 頂点数より1個長くし、存在しないことを-1で表す。末尾(-1)要素は常に-1
        while q:
            now, parent, dep, dist = q.pop()
            directAncestors[now] = parent
            self.depths[now] = dep
            self.distances[now] = dist
            for cost, next in self.G[now]:
                if next != parent:
                    q.append((next, now, dep + 1, dist + cost))
        return directAncestors
 
    def getLca(self, nodeA: int, nodeB: int):
        """
        O(logN)
        """
        depthA, depthB = self.depths[nodeA], self.depths[nodeB]
        if depthA > depthB:
            nodeA, nodeB = nodeB, nodeA
            depthA, depthB = depthB, depthA
        
        # 2ノードを同じ深さまで揃える。
        tu = nodeA
        tv = self.upstream(nodeB, depthB - depthA)

        # 遡上させて行き2つが衝突する位置が共通祖先。
        if nodeA == tv:
            return nodeA
        for k in range(depthA.bit_length() - 1, -1, -1):
            mu = self.ancestors[k][tu]
            mv = self.ancestors[k][tv]
            if mu != mv:
                tu = mu
                tv = mv
        lca = self.ancestors[0][tu]
        assert lca == self.ancestors[0][tv]
        return lca
 
    # 2つのノードの間の距離を返す。
    def getDistance(self, nodeA, nodeB):
        """
        O(logN)
        """
        lca = self.getLca(nodeA, nodeB)
        return self.distances[nodeA] + self.distances[nodeB] - 2 * self.distances[lca]

    # targetNodeが2つのノード間のパス上に存在するかを返す。
    def isOnPath(self, nodeA: int, nodeB: int, evalNode: int):
        """
        O(logN)
        """
        return self.getDistance(nodeA, nodeB) == self.getDistance(nodeA, evalNode) + self.getDistance(evalNode, nodeB) 

    # ノードvからk個遡上したノードを返す。
    def upstream(self, v, k):
        i = 0
        while k:
            if k & 1:
                v = self.ancestors[i][v]
            k >>= 1
            i += 1
        return v

def main():
    N = int(input())
    ld = LcaDoubling(N)
    for _ in range(N - 1):
        u, v = map(int, input().split())
        ld.addEdge(u - 1, v - 1, 1)
        ld.addEdge(v - 1, u - 1, 1)
    ld.build()
    imos = [0] * N

    def dfs(pre: int, now: int):
        for _, next in ld.G[now]: 
            if next == pre: 
                continue
            dfs(now, next)
        if pre != -1:
            imos[pre] += imos[now]
        return

    Q = int(input())
    for i in range(Q):
        a, b = map(int, input().split())
        imos[a - 1] += 1
        imos[b - 1] += 1
        lc = ld.getLca(a - 1, b - 1)
        imos[lc] -= 1
        p_lc = ld.upstream(lc, 1)
        if p_lc != -1:
            imos[p_lc] -= 1
    dfs(-1, 0)
    ans = 0
    for num in imos:
        ans += num * (num + 1) // 2
    print(ans)
    return
        
if __name__ == '__main__':
    main()
0