結果

問題 No.3442 Good Vertex Connectivity
コンテスト
ユーザー kidodesu
提出日時 2026-02-06 23:06:17
言語 PyPy3
(7.3.17)
結果
WA  
実行時間 -
コード長 11,732 bytes
記録
記録タグの例:
初AC ショートコード 純ショートコード 純主流ショートコード 最速実行時間
コンパイル時間 339 ms
コンパイル使用メモリ 82,748 KB
実行使用メモリ 135,508 KB
最終ジャッジ日時 2026-02-06 23:06:58
合計ジャッジ時間 40,655 ms
ジャッジサーバーID
(参考情報)
judge5 / judge1
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 1
other AC * 1 WA * 4 TLE * 8 -- * 56
権限があれば一括ダウンロードができます

ソースコード

diff #
raw source code

# https://github.com/tatyam-prime/SortedSet/blob/main/SortedMultiset.py
import math
from bisect import bisect_left, bisect_right
class SortedMultiset():
    """Sorted multiset stored as a list of sorted buckets.

    ``self.a`` is the list of buckets (each a sorted list) and ``self.size``
    is the total number of stored elements.  The lookup methods read
    ``a[0]`` / ``a[-1]`` of every bucket, so buckets are never left empty:
    ``_pop`` deletes a bucket as soon as it becomes empty.
    """
    BUCKET_RATIO = 16
    SPLIT_RATIO = 24

    def __init__(self, a=()) -> None:
        "Make a new SortedMultiset from iterable. / O(N) if sorted / O(N log N)"
        # The default is an immutable empty tuple rather than a shared list.
        a = list(a)
        n = self.size = len(a)
        # Sort only when the input is not already non-decreasing.
        if any(a[i] > a[i + 1] for i in range(n - 1)):
            a.sort()
        num_bucket = int(math.ceil(math.sqrt(n / self.BUCKET_RATIO)))
        self.a = [a[n * i // num_bucket : n * (i + 1) // num_bucket] for i in range(num_bucket)]

    def __iter__(self):
        "Yield the elements in ascending order."
        for i in self.a:
            for j in i: yield j

    def __reversed__(self):
        "Yield the elements in descending order."
        for i in reversed(self.a):
            for j in reversed(i): yield j

    def __eq__(self, other) -> bool:
        "Compare element-wise against any iterable."
        return list(self) == list(other)

    def __len__(self) -> int:
        return self.size

    def __repr__(self) -> str:
        return "SortedMultiset" + str(self.a)

    def __str__(self) -> str:
        s = str(list(self))
        return "{" + s[1 : len(s) - 1] + "}"

    def _position(self, x):
        "return the bucket, index of the bucket and position in which x should be. self must not be empty."
        # If x is larger than everything, the loop ends on the last bucket
        # and bisect_left returns len(bucket).
        for i, a in enumerate(self.a):
            if x <= a[-1]: break
        return (a, i, bisect_left(a, x))

    def __contains__(self, x):
        if self.size == 0: return False
        a, _, i = self._position(x)
        return i != len(a) and a[i] == x

    def count(self, x) -> int:
        "Count the number of x."
        return self.index_right(x) - self.index(x)

    def add(self, x) -> None:
        "Add an element. / O(√N)"
        if self.size == 0:
            self.a = [[x]]
            self.size = 1
            return
        a, b, i = self._position(x)
        a.insert(i, x)
        self.size += 1
        # Split a bucket in two once it outgrows SPLIT_RATIO * (bucket count).
        if len(a) > len(self.a) * self.SPLIT_RATIO:
            mid = len(a) >> 1
            self.a[b:b+1] = [a[:mid], a[mid:]]

    def _pop(self, a, b: int, i: int):
        "Remove and return a[i], where a is bucket number b; drop the bucket if it empties."
        ans = a.pop(i)
        self.size -= 1
        if not a: del self.a[b]
        return ans

    def discard(self, x) -> bool:
        "Remove an element and return True if removed. / O(√N)"
        if self.size == 0: return False
        a, b, i = self._position(x)
        if i == len(a) or a[i] != x: return False
        self._pop(a, b, i)
        return True

    def lt(self, x):
        "Find the largest element < x, or None if it doesn't exist."
        for a in reversed(self.a):
            if a[0] < x:
                return a[bisect_left(a, x) - 1]

    def le(self, x):
        "Find the largest element <= x, or None if it doesn't exist."
        for a in reversed(self.a):
            if a[0] <= x:
                return a[bisect_right(a, x) - 1]

    def gt(self, x):
        "Find the smallest element > x, or None if it doesn't exist."
        for a in self.a:
            if a[-1] > x:
                return a[bisect_right(a, x)]

    def ge(self, x):
        "Find the smallest element >= x, or None if it doesn't exist."
        for a in self.a:
            if a[-1] >= x:
                return a[bisect_left(a, x)]

    def __getitem__(self, i: int):
        "Return the i-th element (negative i counts from the end); raise IndexError if out of range."
        if i < 0:
            for a in reversed(self.a):
                i += len(a)
                if i >= 0: return a[i]
        else:
            for a in self.a:
                if i < len(a): return a[i]
                i -= len(a)
        raise IndexError

    def pop(self, i: int = -1):
        "Pop and return the i-th element; raise IndexError if out of range."
        if i < 0:
            for b, a in enumerate(reversed(self.a)):
                i += len(a)
                # ~b is the negative index of this bucket inside self.a.
                if i >= 0: return self._pop(a, ~b, i)
        else:
            for b, a in enumerate(self.a):
                if i < len(a): return self._pop(a, b, i)
                i -= len(a)
        raise IndexError

    def index(self, x) -> int:
        "Count the number of elements < x."
        ans = 0
        for a in self.a:
            if a[-1] >= x:
                return ans + bisect_left(a, x)
            ans += len(a)
        return ans

    def index_right(self, x) -> int:
        "Count the number of elements <= x."
        ans = 0
        for a in self.a:
            if a[-1] > x:
                return ans + bisect_right(a, x)
            ans += len(a)
        return ans

class segtree:
    """Bottom-up segment tree over a monoid (op, e).

    Leaves live in ``d[size : size + n]`` and node ``k`` stores
    ``op(d[2k], d[2k + 1])``; unused leaves hold the identity ``e``.
    """
    # Class-level placeholders; __init__ overwrites all of them per instance.
    n = 1
    size = 1
    log = 2
    d = [0]
    op = None
    e = 10**15

    def __init__(self, V, OP, E):
        """Build the tree from sequence V with operation OP and identity E."""
        self.n = len(V)
        self.op = OP
        self.e = E
        self.log = (self.n - 1).bit_length()
        self.size = 1 << self.log
        self.d = [E] * (2 * self.size)
        for offset, leaf in enumerate(V):
            self.d[self.size + offset] = leaf
        # Fill the internal nodes from the deepest level upwards.
        k = self.size - 1
        while k > 0:
            self.update(k)
            k -= 1

    def set(self, p, x):
        """Assign x to position p and recompute every ancestor node."""
        assert 0 <= p < self.n
        node = p + self.size
        self.d[node] = x
        for _ in range(self.log):
            node >>= 1
            self.update(node)

    def get(self, p):
        """Return the value stored at position p."""
        assert 0 <= p < self.n
        return self.d[self.size + p]

    def prod(self, l, r):
        """Return the fold of op over positions [l, r); e when l == r."""
        assert 0 <= l <= r <= self.n
        left_acc = self.e
        right_acc = self.e
        lo = l + self.size
        hi = r + self.size
        while lo < hi:
            if lo & 1:
                left_acc = self.op(left_acc, self.d[lo])
                lo += 1
            if hi & 1:
                hi -= 1
                right_acc = self.op(self.d[hi], right_acc)
            lo >>= 1
            hi >>= 1
        return self.op(left_acc, right_acc)

    def all_prod(self):
        """Return the value stored at the root node."""
        return self.d[1]

    def max_right(self, l, f):
        """Search rightwards from l for the largest r whose fold over [l, r) satisfies f.

        Requires f(e) to be true; returns n when the predicate never fails.
        """
        assert 0 <= l <= self.n
        assert f(self.e)
        if l == self.n:
            return self.n
        node = l + self.size
        acc = self.e
        while True:
            # Climb while the node is a left child.
            while not node & 1:
                node >>= 1
            if f(self.op(acc, self.d[node])):
                acc = self.op(acc, self.d[node])
                node += 1
                # A power of two means we stepped past the right edge.
                if (node & -node) == node:
                    return self.n
                continue
            # The predicate fails inside this node: descend to the exact leaf.
            while node < self.size:
                node <<= 1
                if f(self.op(acc, self.d[node])):
                    acc = self.op(acc, self.d[node])
                    node += 1
            return node - self.size

    def min_left(self, r, f):
        """Search leftwards from r for the smallest l whose fold over [l, r) satisfies f.

        Requires f(e) to be true; returns 0 when the predicate never fails.
        """
        assert 0 <= r <= self.n
        assert f(self.e)
        if r == 0:
            return 0
        node = r + self.size
        acc = self.e
        while True:
            node -= 1
            # Climb while the node is a right child (and not the root).
            while node > 1 and node & 1:
                node >>= 1
            if f(self.op(self.d[node], acc)):
                acc = self.op(self.d[node], acc)
                # A power of two means we reached the left edge.
                if (node & -node) == node:
                    return 0
                continue
            # The predicate fails inside this node: descend to the exact leaf.
            while node < self.size:
                node = 2 * node + 1
                if f(self.op(self.d[node], acc)):
                    acc = self.op(self.d[node], acc)
                    node -= 1
            return node + 1 - self.size

    def update(self, k):
        """Recompute node k from its two children."""
        left = self.d[2 * k]
        right = self.d[2 * k + 1]
        self.d[k] = self.op(left, right)

    def __str__(self):
        """Render the n leaf values as a list literal."""
        return str(self.d[self.size:self.size + self.n])

# Read the vertex count, then n-1 edges given as 1-indexed endpoint pairs,
# and store them as 0-indexed undirected adjacency lists.
n = int(input())
node = [[] for _ in range(n)]
for _ in range(n - 1):
    u, v = [int(tok) - 1 for tok in input().split()]
    node[u].append(v)
    node[v].append(u)

def dubling():
    """Build the binary-lifting ancestor table of the rooted tree.

    Reads the globals ``n`` (vertex count) and ``P`` (parent of each vertex,
    -1 for none).  Returns 30 rows of length n where row ``i`` holds, for
    every vertex ``j``, the vertex reached by moving 2**i steps upward from
    ``j``, or -1 if that walks off the tree.
    """
    table = [[P[j] for j in range(n)]]
    for _ in range(29):
        below = table[-1]
        # Two jumps of 2**(i-1) make one jump of 2**i; -1 stays -1.
        table.append([-1 if half == -1 else below[half] for half in below])
    return table

def move_point(now, k):
    """Return the vertex k levels above ``now`` using the global ``dub`` table.

    Returns -1 when k exceeds the depth ``D[now]``.  Only the low 30 bits of
    k are applied.
    """
    if k > D[now]:
        return -1
    cur = now
    for level in reversed(range(30)):
        if (k >> level) & 1:
            cur = dub[level][cur]
    return cur


# Root the tree at vertex 0 with an explicit stack: P[v] is the parent of v
# (-1 for the root) and D[v] its depth.  D doubles as the "visited" marker.
P = [-1] * n
D = [-1] * n
D[0] = 0
S = [0]
while S:
    now = S.pop()
    below = D[now] + 1
    for nxt in node[now]:
        if D[nxt] != -1:
            continue
        P[nxt] = now
        D[nxt] = below
        S.append(nxt)

# Binary-lifting table built from P.
dub = dubling()

def LCA(u, v):
    """Return the lowest common ancestor of u and v.

    Uses the global depth array ``D`` and binary-lifting table ``dub``.
    """
    # Make v the deeper endpoint, then lift it to u's depth bit by bit.
    if D[u] > D[v]:
        u, v = v, u
    gap = D[v] - D[u]
    level = 0
    while gap:
        if gap & 1:
            v = dub[level][v]
        gap >>= 1
        level += 1
    if u == v:
        return u
    # Jump both vertices upward as long as their ancestors still differ.
    for i in range(29, -1, -1):
        up_u = dub[i][u]
        up_v = dub[i][v]
        if up_u != up_v:
            u, v = up_u, up_v
    return dub[0][u]

def dist(u, v):
    """Return the number of edges on the tree path between u and v."""
    meet = LCA(u, v)
    # Each endpoint contributes its depth below the common ancestor.
    return (D[u] - D[meet]) + (D[v] - D[meet])

# Read the per-vertex flags C, allocate the query structures, and build the
# Euler tour M: a vertex is appended when first entered and again each time
# the walk returns to it from a child.  F[v] = [first, last] position of v
# in M.
C = [int(tok) for tok in input().split()]
S = [0]
M = []
F = [[-1, -1] for _ in range(n)]
E = [0] * 2*n
st = segtree([0] * (2*n), lambda a, b: a + b, 0)
sl = SortedMultiset([])
while S:
    now = S.pop()
    # A negative entry ~v means "the walk comes back to v".
    mow = ~now if now < 0 else now
    pos = len(M)
    span = F[mow]
    if span[0] == -1:
        span[0] = pos
    span[1] = pos
    M.append(mow)
    if now < 0:
        continue
    parent = P[now]
    for nxt in node[now]:
        if nxt != parent:
            # Pushed so that nxt is explored first and ~now pops afterwards.
            S += (~now, nxt)

# Seed the structures from the initial flags.  Every first/last Euler
# position of a flagged vertex goes into sl, and st[position] holds the tree
# distance from the previous registered position's vertex (cyclically: the
# first registered position stores the distance from the last one).
#
# Two fixes compared with the previous version:
#   * iterate over every position of M (the old range(2*n-2) skipped the
#     final position, i.e. the root's last occurrence, and for n == 1
#     skipped everything), matching what the toggle handler registers;
#   * remember the first registered *position* in pi (the old code stored
#     the vertex id there but then used it as a position).
pre = -1  # vertex at the most recently registered position, -1 if none yet
pi = -1   # first registered Euler position, -1 if none
for i, now in enumerate(M):
    if C[now] and i in F[now]:
        if pre != -1:
            st.set(i, dist(pre, now))
        else:
            pi = i
        sl.add(i)
        pre = now

# Close the cycle: the first position's predecessor is the last one.
if pi != -1:
    st.set(pi, dist(M[pi], pre))

#print(F)
#print(M)
# Process the queries.
#   "1 u"   : toggle the flag of vertex u, keeping sl (registered Euler
#             positions) and st (cyclic predecessor distances) in sync.
#   "2 u v" : NOTE(review): interpreted as "root the tree at u and look at
#             the subtree of v" -- this is what the LCA / complement logic
#             below implements; confirm against the problem statement.
#             Prints 0 when no flagged vertex is in that region, otherwise
#             (cyclic distance sum) // 2 + 1.
# Answers are collected and written once at the end.
out = []
for _ in range(int(input())):
    T = list(map(int, input().split()))
    if T[0] == 1:
        u = T[1] - 1
        if not C[u]:
            # Insert both Euler positions of u and re-link the neighbours.
            for v in set(F[u]):
                sl.add(v)
                idx = sl.index(v)
                pu = sl[idx - 1]
                pv = sl[(idx + 1) % len(sl)]
                st.set(v, dist(M[pu], M[v]))
                st.set(pv, dist(M[v], M[pv]))
        else:
            for v in set(F[u]):
                sl.discard(v)
                # Always clear the removed slot, even if sl became empty.
                st.set(v, 0)
                if sl:
                    idx = sl.index(v)
                    pu = sl[idx - 1]
                    pv = sl[idx % len(sl)]
                    st.set(pv, dist(M[pu], M[pv]))
        C[u] ^= 1
        continue

    u, v = T[1] - 1, T[2] - 1
    if not sl:
        out.append(0)
        continue

    if u == v:
        # Root and target coincide: the region is the whole tree, i.e. the
        # whole Euler tour (previously F[v] was used, which is only v's
        # subtree under root 0).
        fs, ft = 0, len(M) - 1
        f = 1
    elif LCA(u, v) == v:
        # v is an ancestor of u under root 0: the region is everything
        # except the subtree of v's child on the path towards u.
        fs, ft = F[move_point(u, D[u] - D[v] - 1)]
        f = 0
    else:
        fs, ft = F[v]
        f = 1

    if f:
        # Region = registered positions inside [fs, ft].
        idx0 = sl.index(fs)
        idx1 = sl.index_right(ft) - 1
        if idx0 > idx1:
            out.append(0)
        else:
            first, last = sl[idx0], sl[idx1]
            # Replace the link entering the range by one closing the cycle.
            ans = st.prod(first, last + 1) - st.get(first)
            ans += dist(M[first], M[last])
            out.append(ans // 2 + 1)
    else:
        # Region = registered positions outside [fs, ft].
        idx0 = sl.index(fs)
        idx1 = sl.index_right(ft)
        total = st.all_prod()
        if idx0 == 0 and idx1 == len(sl):
            out.append(0)
        elif idx0 == idx1:
            out.append(total // 2 + 1)
        else:
            size = len(sl)
            before = sl[(idx0 - 1) % size]
            after = sl[idx1 % size]
            if idx1 == size:
                removed = st.prod(sl[idx0], 2 * n)
            else:
                removed = st.prod(sl[idx0], sl[idx1])
            # Drop the excluded links and bridge the gap directly.
            ans = total - removed + dist(M[before], M[after]) - st.get(after)
            out.append(ans // 2 + 1)

if out:
    print("\n".join(map(str, out)))
0