結果

問題 No.3442 Good Vertex Connectivity
コンテスト
ユーザー 👑 potato167
提出日時 2026-02-03 16:48:40
言語 PyPy3 (7.3.17)
結果
AC  
実行時間 2,421 ms / 3,000 ms
コード長 9,243 bytes
記録
記録タグの例:
初AC ショートコード 純ショートコード 純主流ショートコード 最速実行時間
コンパイル時間 626 ms
コンパイル使用メモリ 82,332 KB
実行使用メモリ 208,608 KB
最終ジャッジ日時 2026-02-06 20:56:53
合計ジャッジ時間 123,738 ms
ジャッジサーバーID (参考情報) judge3 / judge5
このコードへのチャレンジ (要ログイン)
ファイルパターン 結果
sample AC * 1
other AC * 69
権限があれば一括ダウンロードができます

ソースコード

diff #
raw source code

import sys
from array import array

# ------------------------------------------------------------
# Fast scanner (bytes -> int)
# ------------------------------------------------------------
class FastScanner:
    """Whitespace-delimited integer reader over the whole of stdin.

    All input is read once as bytes in ``__init__``; ``int()`` then parses
    tokens from that in-memory buffer by advancing a cursor.
    """

    __slots__ = ("data", "i", "n")

    def __init__(self):
        self.data = sys.stdin.buffer.read()  # entire input as bytes
        self.i = 0  # read cursor into self.data
        self.n = len(self.data)

    def int(self) -> int:
        """Return the next integer token and advance the cursor past it.

        Bytes <= 32 (space, tab, CR, LF, ...) act as separators. A single
        leading ``-`` is accepted as a sign. Returns 0 once the input is
        exhausted. Digits are not validated: any other non-separator byte
        is folded in as ``byte - 48``.
        """
        data = self.data
        n = self.n
        i = self.i
        while i < n and data[i] <= 32:
            i += 1
        # Optional sign; without this, b"-" would be folded in as the digit -3.
        neg = i < n and data[i] == 45  # 45 == ord("-")
        if neg:
            i += 1
        num = 0
        while i < n and data[i] > 32:
            num = num * 10 + (data[i] - 48)
            i += 1
        self.i = i
        return -num if neg else num


def solve() -> None:
    """Read a tree, vertex colors and queries from stdin and answer each query.

    Input order (as consumed below): N; N-1 edges ``a b`` (1-indexed);
    N colors; Q; then Q queries.
      * ``1 v``   toggles color[v] (via ``^= 1``).
      * ``2 x y`` appends the number of vertices of the minimal connected
        subgraph (Steiner tree) spanning the colored vertices that lie in the
        subtree of ``y`` when the tree is rooted at ``x``. When ``x == y``
        this is the whole tree.

    Technique: list a set of vertices v1..vk in DFS preorder. The cyclic sum
    dist(v1,v2) + ... + dist(vk,v1) walks every Steiner-tree edge exactly
    twice. A segment tree over preorder positions stores, per node, the
    colored count, the first and last colored vertex, and the sum of
    consecutive distances. Any preorder interval can therefore be combined
    with O(log N) merges, each needing one O(1) sparse-table LCA.
    """
    fs = FastScanner()
    N = fs.int()

    # Adjacency lists; vertices are 1-indexed (index 0 unused).
    g = [[] for _ in range(N + 1)]
    for _ in range(N - 1):
        a = fs.int()
        b = fs.int()
        g[a].append(b)
        g[b].append(a)

    # color[v] truthy means v is currently "colored". Values are toggled
    # with ^= 1 below, so they are presumably 0/1.
    color = [0] * (N + 1)
    for i in range(1, N + 1):
        color[i] = fs.int()

    # -----------------------------
    # Rooted tree at 1:
    # tin/tout, depth, parent
    # Euler-in for subtree interval
    # Euler tour for RMQ-LCA: euler2, dep2
    # -----------------------------
    depth = [0] * (N + 1)
    parent = [0] * (N + 1)
    tin = [0] * (N + 1)   # preorder index of v
    tout = [0] * (N + 1)  # last preorder index inside v's subtree (inclusive)
    euler_in = [0] * N    # inverse of tin: preorder index -> vertex

    first = [-1] * (N + 1)  # first position of v in euler2
    euler2 = array('I')     # Euler tour; a vertex is re-listed after each child
    dep2 = array('I')       # depth of each euler2 entry

    timer = 0
    # it[v] is the index of the next neighbour of v to examine. It replaces
    # the state of a recursive DFS, so deep trees need no recursion.
    it = [0] * (N + 1)
    stack = [1]
    parent[1] = 0
    depth[1] = 0

    while stack:
        v = stack[-1]
        if it[v] == 0:
            # v is on top of the stack for the first time: give it the next
            # preorder slot and its first Euler-tour entry.
            tin[v] = timer
            euler_in[timer] = v
            timer += 1

            if first[v] == -1:
                first[v] = len(euler2)
            euler2.append(v)
            dep2.append(depth[v])

        if it[v] < len(g[v]):
            to = g[v][it[v]]
            it[v] += 1
            if to == parent[v]:
                continue
            parent[to] = v
            depth[to] = depth[v] + 1
            stack.append(to)
        else:
            # All neighbours are done: v's subtree occupies
            # preorder slots tin[v]..timer-1.
            tout[v] = timer - 1
            stack.pop()
            if stack:
                # Returning to the parent re-lists it in the Euler tour.
                p = stack[-1]
                euler2.append(p)
                dep2.append(depth[p])

    M = len(euler2)  # ~ 2N-1

    # -----------------------------
    # Sparse Table for RMQ over dep2
    # store indices into euler2
    # -----------------------------
    # logs[x] == floor(log2(x)) for x >= 1.
    logs = [0] * (M + 1)
    for i in range(2, M + 1):
        logs[i] = logs[i >> 1] + 1

    K = logs[M] + 1
    # st[j][i] is an index into euler2 of a minimum-depth entry within
    # euler2[i : i + 2**j].
    st = [None] * K
    st0 = array('I', range(M))
    st[0] = st0

    j = 1
    while (1 << j) <= M:
        prev = st[j - 1]
        span = 1 << (j - 1)
        new_len = M - (1 << j) + 1
        cur = array('I', [0]) * new_len
        # argmin depth
        for i in range(new_len):
            i1 = prev[i]
            i2 = prev[i + span]
            cur[i] = i1 if dep2[i1] <= dep2[i2] else i2
        st[j] = cur
        j += 1

    # local bindings for speed
    _first = first
    _logs = logs
    _st = st
    _euler2 = euler2
    _dep2 = dep2
    _depth = depth

    def lca(a: int, b: int) -> int:
        # The shallowest Euler-tour entry between the first occurrences of
        # a and b is their LCA. Two overlapping power-of-two windows cover
        # that range, so each call is O(1).
        ia = _first[a]
        ib = _first[b]
        if ia > ib:
            ia, ib = ib, ia
        k = _logs[ib - ia + 1]
        t = _st[k]
        i1 = t[ia]
        i2 = t[ib - (1 << k) + 1]
        return _euler2[i1] if _dep2[i1] <= _dep2[i2] else _euler2[i2]

    def dist(a: int, b: int) -> int:
        # Number of edges on the a-b path.
        c = lca(a, b)
        return _depth[a] + _depth[b] - 2 * _depth[c]

    def is_ancestor(a: int, b: int) -> bool:
        # True iff a is an ancestor of b (or a == b) in the tree rooted at 1.
        return tin[a] <= tin[b] and tout[b] <= tout[a]

    # -----------------------------
    # Binary lifting for jump_up only (use array('I') to reduce overhead)
    # -----------------------------
    # up[k][v] is the 2**k-th ancestor of v. Vertex 0 acts as a sink:
    # parent[1] == 0 and up[k][0] stays 0. LOG bits suffice because every
    # jump requested below is at most N-2.
    LOG = (N).bit_length()
    up = [None] * LOG
    up0 = array('I', parent)  # parent[0..N]
    up[0] = up0
    for k in range(1, LOG):
        prev = up[k - 1]
        cur = array('I', [0]) * (N + 1)
        for v in range(1, N + 1):
            cur[v] = prev[prev[v]]
        up[k] = cur

    def jump_up(v: int, k: int) -> int:
        # Return the k-th ancestor of v by applying one lifting table per
        # set bit of k.
        bit = 0
        while k:
            if k & 1:
                v = up[bit][v]
            k >>= 1
            bit += 1
        return v

    # -----------------------------
    # Segment Tree over Euler-in [0..N)
    # store as 4 arrays to avoid tuple allocations:
    # cnt, fv, lv, sd(sum distances consecutive)
    # Use 0 as "empty vertex" sentinel.
    # -----------------------------
    # For the preorder range of a node:
    #   cnt = number of colored vertices,
    #   fv / lv = first / last colored vertex in preorder,
    #   sd = sum of dist() between consecutive colored vertices.
    size = 1
    while size < N:
        size <<= 1
    segN = 2 * size

    cnt = array('I', [0]) * segN
    fv = array('I', [0]) * segN
    lv = array('I', [0]) * segN
    sd = array('Q', [0]) * segN  # 64-bit sum

    # build leaves
    base = size
    for i in range(N):
        v = euler_in[i]
        if color[v]:
            idx = base + i
            cnt[idx] = 1
            fv[idx] = v
            lv[idx] = v
            sd[idx] = 0

    # merge function inline for build/update/query
    # Merging two non-empty halves adds the distance across the seam
    # (last colored vertex of the left half -> first of the right half).
    for i in range(size - 1, 0, -1):
        L = i << 1
        R = L | 1
        cL = cnt[L]
        cR = cnt[R]
        if cL == 0:
            cnt[i] = cR
            fv[i] = fv[R]
            lv[i] = lv[R]
            sd[i] = sd[R]
        elif cR == 0:
            cnt[i] = cL
            fv[i] = fv[L]
            lv[i] = lv[L]
            sd[i] = sd[L]
        else:
            cnt[i] = cL + cR
            fv[i] = fv[L]
            lv[i] = lv[R]
            sd[i] = sd[L] + sd[R] + dist(lv[L], fv[R])

    def update_pos(pos: int, v: int, on: int) -> None:
        # Set preorder leaf `pos` (holding vertex v) to colored/uncolored,
        # then recompute every ancestor with the same merge rule as the build.
        i = base + pos
        if on:
            cnt[i] = 1
            fv[i] = v
            lv[i] = v
            sd[i] = 0
        else:
            cnt[i] = 0
            fv[i] = 0
            lv[i] = 0
            sd[i] = 0
        i >>= 1
        while i:
            L = i << 1
            R = L | 1
            cL = cnt[L]
            cR = cnt[R]
            if cL == 0:
                cnt[i] = cR
                fv[i] = fv[R]
                lv[i] = lv[R]
                sd[i] = sd[R]
            elif cR == 0:
                cnt[i] = cL
                fv[i] = fv[L]
                lv[i] = lv[L]
                sd[i] = sd[L]
            else:
                cnt[i] = cL + cR
                fv[i] = fv[L]
                lv[i] = lv[R]
                sd[i] = sd[L] + sd[R] + dist(lv[L], fv[R])
            i >>= 1

    def query(l: int, r: int):
        # returns (c, f, t, s) as 4 ints (no tuple nodes stored in seg)
        # Folds preorder positions [l, r) bottom-up. Pieces taken on the
        # left boundary are appended to the left accumulator. Pieces taken
        # on the right boundary are prepended to the right accumulator, so
        # preorder order is preserved. The two are joined at the end.
        # (c, f, t, s) = (count, first vertex, last vertex, distance sum).
        cL = 0
        fL = 0
        tL = 0
        sL = 0
        cR = 0
        fR = 0
        tR = 0
        sR = 0

        l += base
        r += base
        while l < r:
            if l & 1:
                c = cnt[l]
                if c:
                    if cL == 0:
                        cL = c
                        fL = fv[l]
                        tL = lv[l]
                        sL = sd[l]
                    else:
                        sL = sL + sd[l] + dist(tL, fv[l])
                        tL = lv[l]
                        cL += c
                l += 1
            if r & 1:
                r -= 1
                c = cnt[r]
                if c:
                    if cR == 0:
                        cR = c
                        fR = fv[r]
                        tR = lv[r]
                        sR = sd[r]
                    else:
                        # Node r precedes everything already in the right
                        # accumulator, so it becomes the new first vertex.
                        sR = sd[r] + sR + dist(lv[r], fR)
                        fR = fv[r]
                        cR += c
            l >>= 1
            r >>= 1

        if cL == 0:
            return cR, fR, tR, sR
        if cR == 0:
            return cL, fL, tL, sL
        return cL + cR, fL, tR, sL + sR + dist(tL, fR)

    def steiner_vertices(c: int, f: int, t: int, s: int) -> int:
        # Vertex count of the minimal subtree connecting c colored vertices,
        # given the first (f) and last (t) of them in preorder and the
        # consecutive-distance sum s. Closing the cycle t -> f makes the walk
        # cover each Steiner edge exactly twice, so edges = cycle // 2.
        if c == 0:
            return 0
        if c == 1:
            return 1
        cycle = s + dist(t, f)
        return (cycle // 2) + 1

    def query_subtree(u: int) -> int:
        # Answer for subtree(u) in the tree rooted at 1, which is the
        # contiguous preorder range [tin[u], tout[u]].
        l = tin[u]
        r = tout[u] + 1
        c, f, t, s = query(l, r)
        return steiner_vertices(c, f, t, s)

    def query_complement_subtree(u: int) -> int:
        # Answer for all vertices outside subtree(u): preorder ranges
        # [0, tin[u]) and (tout[u], N), concatenated in preorder order.
        l = tin[u]
        r = tout[u] + 1
        c1, f1, t1, s1 = query(0, l)
        c2, f2, t2, s2 = query(r, N)
        if c1 == 0:
            return steiner_vertices(c2, f2, t2, s2)
        if c2 == 0:
            return steiner_vertices(c1, f1, t1, s1)
        c = c1 + c2
        f = f1
        t = t2
        s = s1 + s2 + dist(t1, f2)
        return steiner_vertices(c, f, t, s)

    # -----------------------------
    # Process queries
    # -----------------------------
    Q = fs.int()
    out = []
    append = out.append

    for _ in range(Q):
        t = fs.int()
        if t == 1:
            # Toggle v's color and refresh its leaf (position tin[v]).
            v = fs.int()
            color[v] ^= 1
            update_pos(tin[v], v, color[v])
        else:
            x = fs.int()
            y = fs.int()
            if x == y:
                # Rooted at x, the subtree of x is the whole tree.
                c, f, t2_, s = query(0, N)
                append(str(steiner_vertices(c, f, t2_, s)))
            else:
                if is_ancestor(y, x):
                    # Rerooting at x turns y's subtree into "everything except
                    # the subtree of y's child z on the path towards x".
                    z = jump_up(x, depth[x] - depth[y] - 1)
                    append(str(query_complement_subtree(z)))
                else:
                    # x lies outside subtree(y), so rerooting at x does not
                    # change the subtree of y.
                    append(str(query_subtree(y)))

    # NOTE(review): answers are joined without a trailing newline.
    sys.stdout.write("\n".join(out))


if __name__ == "__main__":
    # Script entry point: read all input from stdin and write the answers.
    solve()
0