結果

問題 No.3442 Good Vertex Connectivity
コンテスト
ユーザー 👑 potato167
提出日時 2026-01-04 16:37:53
言語 PyPy3
(7.3.17)
結果
AC  
実行時間 2,569 ms / 3,000 ms
コード長 7,054 bytes
記録
記録タグの例:
初AC ショートコード 純ショートコード 純主流ショートコード 最速実行時間
コンパイル時間 478 ms
コンパイル使用メモリ 82,356 KB
実行使用メモリ 279,984 KB
最終ジャッジ日時 2026-02-06 20:54:38
合計ジャッジ時間 119,632 ms
ジャッジサーバーID
(参考情報)
judge3 / judge2
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 1
other AC * 69
権限があれば一括ダウンロードができます

ソースコード

diff #
raw source code

import sys
from array import array

# ------------------------------------------------------------
# RMQ-LCA (Euler tour + Sparse Table for depth-minimum)
#   - LCA: O(1)
#   - dist: O(1)
# Segment Tree over Euler-in order for active (black) nodes:
#   - update/query: O(log N)
# Overall: O((N+Q) log N) time, O(N log N) memory (Sparse Table)
# ------------------------------------------------------------

def _euler_dfs(g: list, N: int):
    """Run an iterative DFS from vertex 1 over 1-indexed adjacency lists ``g``.

    Returns ``(depth, parent, tin, tout, euler_in, first, euler2, dep2)``:

    - ``depth[v]`` / ``parent[v]``: depth and parent in the tree rooted at 1
      (``parent[1] == 0``; index 0 is an unused sentinel that stays 0).
    - ``tin[v]``: preorder index of ``v``; ``tout[v]``: the largest preorder
      index inside v's subtree, so the subtree occupies ``tin[v]..tout[v]``.
    - ``euler_in[i]``: the vertex whose preorder index is ``i``.
    - ``euler2`` / ``dep2``: Euler tour (``v`` on entry and again after each
      child returns, ``2N - 1`` entries) and the depth of every entry.
    - ``first[v]``: index of v's first occurrence in ``euler2``.
    """
    depth = [0] * (N + 1)
    parent = [0] * (N + 1)
    tin = [0] * (N + 1)
    tout = [0] * (N + 1)
    euler_in = [0] * N
    first = [-1] * (N + 1)
    euler2 = array('I')  # vertex ids
    dep2 = array('I')    # depth of each euler2 entry

    # it[v] is the next position of g[v] to examine.  it[v] == 0 only on the
    # first time v is on top of the stack: every later step advances it,
    # except for an isolated root (N == 1), which is popped right away.
    it = [0] * (N + 1)
    timer = 0
    stack = [1]
    while stack:
        v = stack[-1]
        if it[v] == 0:
            tin[v] = timer
            euler_in[timer] = v
            timer += 1
            first[v] = len(euler2)
            euler2.append(v)
            dep2.append(depth[v])

        adj = g[v]
        if it[v] < len(adj):
            to = adj[it[v]]
            it[v] += 1
            if to != parent[v]:
                parent[to] = v
                depth[to] = depth[v] + 1
                stack.append(to)
        else:
            tout[v] = timer - 1
            stack.pop()
            if stack:
                # The Euler tour revisits the parent after each child.
                p = stack[-1]
                euler2.append(p)
                dep2.append(depth[p])
    return depth, parent, tin, tout, euler_in, first, euler2, dep2


def _build_sparse_table(dep2):
    """Build a sparse table for range-argmin queries over ``dep2``.

    Returns ``(logs, st)``.  ``logs[n] == floor(log2(n))`` for ``n >= 1``.
    ``st[k][i]`` is the index of a minimum of ``dep2[i : i + 2**k]``.
    """
    M = len(dep2)
    logs = array('I', [0]) * (M + 1)
    for i in range(2, M + 1):
        logs[i] = logs[i >> 1] + 1

    st = [array('I', range(M))]
    half = 1
    while 2 * half <= M:
        prev = st[-1]
        new_len = M - 2 * half + 1
        cur = array('I', [0]) * new_len
        for i in range(new_len):
            i1 = prev[i]
            i2 = prev[i + half]
            cur[i] = i1 if dep2[i1] <= dep2[i2] else i2
        st.append(cur)
        half *= 2
    return logs, st


def _build_lifting(parent: list, N: int) -> list:
    """Return ``up`` with ``up[k][v]`` = the ``2**k``-th ancestor of ``v``.

    Jumping past the root yields 0, because ``parent[1] == 0`` and
    ``parent[0] == 0``.  There are ``N.bit_length()`` levels, which is
    enough for any jump shorter than the tree height.
    """
    up = [list(parent)]
    for _ in range(1, N.bit_length()):
        prev = up[-1]
        up.append([prev[p] for p in prev])
    return up


def solve():
    """Read a tree, vertex colors and queries from stdin; print the answers.

    Input is line-oriented:
    - ``N``;
    - ``N - 1`` edge lines ``a b``;
    - one line of ``N`` colors (a vertex counts as black iff its color is
      ``1``; presumably the values are 0/1);
    - ``Q``;
    - ``Q`` query lines.

    The queries are:

    - ``1 v``: toggle the color of ``v`` (``color ^= 1``).
    - ``2 x y``: take the vertices of y's subtree when the tree is rooted at
      ``x`` (the whole tree when ``x == y``).  Print the number of vertices of
      the smallest connected subgraph containing all black vertices among
      them, or 0 if there are none.

    Black vertices live in a segment tree indexed by preorder position (root
    1).  For black vertices ``b_1..b_k`` sorted by preorder, the cyclic sum of
    ``dist(b_i, b_{i+1})`` is twice the edge count of their spanning subtree.
    Each node therefore stores ``(count, first, last, sum of consecutive
    distances)``, and distances use an Euler-tour + sparse-table LCA.
    """
    readline = sys.stdin.buffer.readline
    N = int(readline())
    g = [[] for _ in range(N + 1)]
    for _ in range(N - 1):
        a, b = map(int, readline().split())
        g[a].append(b)
        g[b].append(a)
    color = [0] + list(map(int, readline().split()))

    depth, parent, tin, tout, euler_in, first, euler2, dep2 = _euler_dfs(g, N)
    logs, st = _build_sparse_table(dep2)
    # Binary lifting is only needed to find the child of y towards x.
    up = _build_lifting(parent, N)

    def lca(a: int, b: int) -> int:
        """Return the lowest common ancestor of ``a`` and ``b`` (O(1))."""
        ia = first[a]
        ib = first[b]
        if ia > ib:
            ia, ib = ib, ia
        k = logs[ib - ia + 1]
        i1 = st[k][ia]
        i2 = st[k][ib - (1 << k) + 1]
        return euler2[i1] if dep2[i1] <= dep2[i2] else euler2[i2]

    def dist(a: int, b: int) -> int:
        """Return the number of edges on the path between ``a`` and ``b``."""
        return depth[a] + depth[b] - 2 * depth[lca(a, b)]

    def jump_up(v: int, k: int) -> int:
        """Return the ``k``-th ancestor of ``v`` via binary lifting."""
        bit = 0
        while k:
            if k & 1:
                v = up[bit][v]
            k >>= 1
            bit += 1
        return v

    # Segment-tree value over a preorder range: (count of black vertices,
    # first black vertex, last black vertex, sum of distances between
    # consecutive black vertices).  EMPTY is the identity of merge().
    EMPTY = (0, -1, -1, 0)

    def merge(left, right):
        """Combine two adjacent preorder ranges (``left`` before ``right``)."""
        if left[0] == 0:
            return right
        if right[0] == 0:
            return left
        return (left[0] + right[0], left[1], right[2],
                left[3] + right[3] + dist(left[2], right[1]))

    def steiner_vertices(agg) -> int:
        """Vertex count of the subtree spanning the black vertices in ``agg``."""
        cnt, fv, lv, s = agg
        if cnt <= 1:
            return cnt
        # Closing the cycle gives twice the edge count; vertices = edges + 1.
        return (s + dist(lv, fv)) // 2 + 1

    size = 1
    while size < N:
        size <<= 1
    seg = [EMPTY] * (2 * size)
    for i in range(N):
        v = euler_in[i]
        if color[v] == 1:
            seg[size + i] = (1, v, v, 0)
    for i in range(size - 1, 0, -1):
        seg[i] = merge(seg[i << 1], seg[i << 1 | 1])

    def update(pos: int, v_or_minus1: int):
        """Set preorder slot ``pos`` to black vertex ``v`` or clear it (``-1``)."""
        i = size + pos
        if v_or_minus1 == -1:
            seg[i] = EMPTY
        else:
            seg[i] = (1, v_or_minus1, v_or_minus1, 0)
        i >>= 1
        while i:
            seg[i] = merge(seg[i << 1], seg[i << 1 | 1])
            i >>= 1

    def query(l: int, r: int):
        """Return the aggregate of preorder slots ``[l, r)``."""
        left = right = EMPTY
        l += size
        r += size
        while l < r:
            if l & 1:
                left = merge(left, seg[l])
                l += 1
            if r & 1:
                r -= 1
                right = merge(seg[r], right)
            l >>= 1
            r >>= 1
        return merge(left, right)

    Q = int(readline())
    out = []
    for _ in range(Q):
        parts = readline().split()
        if int(parts[0]) == 1:
            v = int(parts[1])
            color[v] ^= 1
            update(tin[v], v if color[v] == 1 else -1)
            continue

        x = int(parts[1])
        y = int(parts[2])
        if x == y:
            agg = query(0, N)
        elif tin[y] <= tin[x] and tout[x] <= tout[y]:
            # y is an ancestor of x: rooted at x, y's subtree is everything
            # except the subtree of y's child z on the path towards x.
            z = jump_up(x, depth[x] - depth[y] - 1)
            agg = merge(query(0, tin[z]), query(tout[z] + 1, N))
        else:
            # x is outside y's subtree, so rerooting at x does not change it.
            agg = query(tin[y], tout[y] + 1)
        out.append(str(steiner_vertices(agg)))

    if out:
        sys.stdout.write("\n".join(out) + "\n")

# Entry point: run the solver only when executed as a script, not on import.
if __name__ == "__main__":
    solve()
0