結果

問題 No.3442 Good Vertex Connectivity
コンテスト
ユーザー 👑 potato167
提出日時 2026-01-04 16:37:16
言語 PyPy3
(7.3.17)
結果
TLE  
実行時間 -
コード長 5,585 bytes
記録
記録タグの例:
初AC ショートコード 純ショートコード 純主流ショートコード 最速実行時間
コンパイル時間 331 ms
コンパイル使用メモリ 82,816 KB
実行使用メモリ 109,412 KB
最終ジャッジ日時 2026-02-06 20:52:58
合計ジャッジ時間 23,533 ms
ジャッジサーバーID
(参考情報)
judge4 / judge5
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 1
other AC * 15 TLE * 1 -- * 53
権限があれば一括ダウンロードができます

ソースコード

diff #
raw source code

import sys
# Generous recursion ceiling kept as a safety net; the tree traversal in this
# file is written iteratively, so deep recursion is not expected.
sys.setrecursionlimit(10 ** 6)

# ------------------------------------------------------------
# Preorder tour (iterative DFS) + O(1) LCA-depth via sparse table,
# segment tree over active vertices in preorder, query driver.
# ------------------------------------------------------------


def _run(data: bytes) -> str:
    """Answer every query of one input file given as raw bytes.

    Token layout: N, then N-1 edges ``a b``, then N colours, then Q, then Q
    queries.  ``1 v`` toggles the colour of v (``color ^= 1``).  Any other
    query ``t x y`` reports, for the tree rooted at x, the number of vertices
    of the minimal connected subgraph spanning the colour-1 vertices in y's
    subtree (x == y means the whole tree); 0 when there are none.

    Returns the answers joined by newlines (no trailing newline).

    Complexity: O(N log N) preprocessing, O(log N) per query/update; every
    pairwise distance is O(1) thanks to the sparse table, which is what
    removes the former O(log^2 N) per-operation cost.
    """
    from bisect import bisect_right

    tok = data.split()
    n = int(tok[0])
    pos = 1

    adj = [[] for _ in range(n + 1)]
    for _ in range(n - 1):
        a = int(tok[pos])
        b = int(tok[pos + 1])
        pos += 2
        adj[a].append(b)
        adj[b].append(a)

    color = [0] * (n + 1)
    for v in range(1, n + 1):
        color[v] = int(tok[pos])
        pos += 1

    # ---- iterative DFS from vertex 1 ----
    # tin[v]     : preorder position of v
    # tout[v]    : largest preorder position inside v's subtree
    # order[i]   : vertex at preorder position i
    # kid_tin[v] : preorder positions of v's children (increasing)
    parent = [0] * (n + 1)
    depth = [0] * (n + 1)
    tin = [0] * (n + 1)
    tout = [0] * (n + 1)
    order = [0] * n
    kid_tin = [[] for _ in range(n + 1)]
    nxt = [0] * (n + 1)  # next adjacency index to examine for each vertex
    order[0] = 1
    timer = 1
    stack = [1]
    while stack:
        v = stack[-1]
        nbrs = adj[v]
        i = nxt[v]
        if i < len(nbrs):
            nxt[v] = i + 1
            w = nbrs[i]
            if w == parent[v]:
                continue
            parent[w] = v
            depth[w] = depth[v] + 1
            tin[w] = timer
            order[timer] = w
            kid_tin[v].append(timer)
            timer += 1
            stack.append(w)
        else:
            tout[v] = timer - 1
            stack.pop()

    # ---- O(1) distance between preorder positions ----
    # For positions i < j, every vertex at positions i+1..j lies strictly
    # inside the LCA's subtree and the LCA's child towards order[j] is among
    # them, so depth(lca) = min(dep[i+1..j]) - 1.
    dep = [depth[v] for v in order]
    sparse = [dep]
    k = 1
    while (1 << k) <= n:
        prev = sparse[-1]
        half = 1 << (k - 1)
        row = [0] * (n - (1 << k) + 1)
        for i in range(len(row)):
            a = prev[i]
            b = prev[i + half]
            row[i] = a if a < b else b
        sparse.append(row)
        k += 1

    def gap(i: int, j: int) -> int:
        """Tree distance between the vertices at preorder positions i < j."""
        lvl = (j - i).bit_length() - 1
        row = sparse[lvl]
        a = row[i + 1]
        b = row[j - (1 << lvl) + 1]
        m = a if a < b else b
        return dep[i] + dep[j] - 2 * (m - 1)

    # ---- segment tree over preorder positions, stored in flat arrays ----
    # Per node: count of active leaves, first/last active position, and the
    # sum of distances between consecutive active positions.
    size = 1
    while size < n:
        size <<= 1
    cnt = [0] * (2 * size)
    fst = [0] * (2 * size)
    lst = [0] * (2 * size)
    tot = [0] * (2 * size)

    def pull(x: int) -> None:
        """Recompute internal node x from its two children."""
        lc = x << 1
        rc = lc | 1
        if cnt[lc] == 0:
            src = rc
        elif cnt[rc] == 0:
            src = lc
        else:
            cnt[x] = cnt[lc] + cnt[rc]
            fst[x] = fst[lc]
            lst[x] = lst[rc]
            tot[x] = tot[lc] + tot[rc] + gap(lst[lc], fst[rc])
            return
        cnt[x] = cnt[src]
        fst[x] = fst[src]
        lst[x] = lst[src]
        tot[x] = tot[src]

    def set_active(p: int, on: bool) -> None:
        """Mark preorder position p active/inactive and refresh ancestors."""
        x = size + p
        if on:
            cnt[x] = 1
            fst[x] = lst[x] = p
        else:
            cnt[x] = 0
        tot[x] = 0
        x >>= 1
        while x:
            pull(x)
            x >>= 1

    for p in range(n):
        if color[order[p]] == 1:
            x = size + p
            cnt[x] = 1
            fst[x] = lst[x] = p
    for x in range(size - 1, 0, -1):
        pull(x)

    # Aggregates outside the tree are tuples (cnt, first, last, sum).
    empty = (0, -1, -1, 0)

    def join(left, right):
        """Concatenate two aggregates, left range before right range."""
        if left[0] == 0:
            return right
        if right[0] == 0:
            return left
        return (
            left[0] + right[0],
            left[1],
            right[2],
            left[3] + right[3] + gap(left[2], right[1]),
        )

    def query(lo: int, hi: int):
        """Aggregate of preorder positions [lo, hi)."""
        left = right = empty
        lo += size
        hi += size
        while lo < hi:
            if lo & 1:
                left = join(left, (cnt[lo], fst[lo], lst[lo], tot[lo]))
                lo += 1
            if hi & 1:
                hi -= 1
                right = join((cnt[hi], fst[hi], lst[hi], tot[hi]), right)
            lo >>= 1
            hi >>= 1
        return join(left, right)

    def steiner_vertices(agg) -> int:
        """Vertex count of the minimal subtree spanning the active set."""
        c, first, last, s = agg
        if c == 0:
            return 0
        if c == 1:
            return 1
        # Closing the cyclic preorder walk visits every Steiner edge twice.
        return (s + gap(first, last)) // 2 + 1

    # ---- queries ----
    q = int(tok[pos])
    pos += 1
    out = []
    for _ in range(q):
        t = int(tok[pos])
        if t == 1:
            v = int(tok[pos + 1])
            pos += 2
            color[v] ^= 1
            set_active(tin[v], color[v] == 1)
            continue
        x = int(tok[pos + 1])
        y = int(tok[pos + 2])
        pos += 3
        if x == y:
            res = steiner_vertices(query(0, n))
        elif tin[y] <= tin[x] <= tout[y]:
            # y is an ancestor of x (root 1): rooted at x, y's subtree is
            # everything except the subtree of y's child z towards x.
            kids = kid_tin[y]
            zt = kids[bisect_right(kids, tin[x]) - 1]
            zend = tout[order[zt]] + 1
            res = steiner_vertices(join(query(0, zt), query(zend, n)))
        else:
            # Rooting at x does not change y's subtree.
            res = steiner_vertices(query(tin[y], tout[y] + 1))
        out.append(str(res))
    return "\n".join(out)


def solve():
    """Read all input from stdin, answer the queries, print one answer per line."""
    ans = _run(sys.stdin.buffer.read())
    if ans:
        sys.stdout.write(ans + "\n")

if __name__ == "__main__":
    # Run the solver only when executed as a script, not when imported.
    solve()
0