結果

問題 No.3309 Aging Railway
コンテスト
ユーザー yu23578
提出日時 2025-09-18 20:15:46
言語 PyPy3
(7.3.15)
結果
AC  
実行時間 995 ms / 3,000 ms
コード長 2,918 bytes
コンパイル時間 287 ms
コンパイル使用メモリ 82,476 KB
実行使用メモリ 220,504 KB
最終ジャッジ日時 2025-10-09 22:41:43
合計ジャッジ時間 13,821 ms
ジャッジサーバーID
(参考情報)
judge5 / judge2
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
other AC * 20
権限があれば一括ダウンロードができます

ソースコード

diff #

class UnionFind:
    def __init__(self):
        self.Nodes = 0
        self.par = []
        self.siz = []

    # 頂点数の設定
    def init(self, N: int):
        assert 0 <= N
        for _ in range(self.Nodes, N):
            self.par.append(-1)
            self.siz.append(1)
        self.Nodes = N

    # リーダーの取得
    def leader(self, x: int) -> int:
        assert 0 <= x < self.Nodes
        if self.par[x] == -1:
            return x
        self.par[x] = self.leader(self.par[x])
        return self.par[x]

    # aとbが同じ連結成分か判定
    def same(self, a: int, b: int) -> bool:
        assert 0 <= a < self.Nodes
        assert 0 <= b < self.Nodes
        return self.leader(a) == self.leader(b)

    # aとbをつなぐ
    def merge(self, a: int, b: int) -> bool:
        assert 0 <= a < self.Nodes
        assert 0 <= b < self.Nodes
        pa = self.leader(a)
        pb = self.leader(b)
        res = (pa == pb)
        if not res:
            if self.siz[pa] < self.siz[pb]:
                self.par[pa] = pb
                self.siz[pb] += self.siz[pa]
            else:
                self.par[pb] = pa
                self.siz[pa] += self.siz[pb]
        return res

    # 同じ連結成分のサイズ
    def size(self, a: int) -> int:
        assert 0 <= a < self.Nodes
        return self.siz[self.leader(a)]

    # 同じ連結成分内の頂点(O(N))
    def same_node(self, a: int):
        assert 0 <= a < self.Nodes
        return [i for i in range(self.Nodes) if self.same(a, i)]

    # 連結成分ごとのグループを返す
    def groups(self):
        res = [[] for _ in range(self.Nodes)]
        for i in range(self.Nodes):
            res[self.leader(i)].append(i)
        return [g for g in res if g]


# ---- main ----
def main():
    import sys
    input = sys.stdin.readline

    N, M = map(int, input().split())
    edges = [tuple(map(int, input().split())) for _ in range(N - 1)]
    edges = [(u - 1, v - 1) for u, v in edges]

    d = UnionFind()
    d.init(N)

    # 各段階のUnionFindを保持
    E = [None] * (N - 1)
    for i in reversed(range(N - 1)):
        d.merge(edges[i][0], edges[i][1])
        # Pythonでは参照コピーなので、ディープコピーを作る
        uf_copy = UnionFind()
        uf_copy.Nodes = d.Nodes
        uf_copy.par = d.par.copy()
        uf_copy.siz = d.siz.copy()
        E[i] = uf_copy

    ans = [0] * (N - 1)
    ans[0] = M

    for _ in range(M):
        s, t = map(int, input().split())
        s -= 1
        t -= 1
        ok = 0
        ng = N - 1
        while abs(ok - ng) > 1:
            mid = (ok + ng) // 2
            if E[mid].same(s, t):
                ok = mid
            else:
                ng = mid
        ans[ok] -= 1

    for i in range(1, N - 1):
        ans[i] += ans[i - 1]

    for a in ans:
        print(a)


if __name__ == "__main__":
    main()
0