import sys
from array import array
from bisect import bisect_right


# ------------------------------------------------------------
# Fast scanner (bytes -> int)
# ------------------------------------------------------------
class FastScanner:
    """Read all of stdin up front and return its integers one at a time."""

    __slots__ = ("data", "i", "n")

    def __init__(self):
        self.data = sys.stdin.buffer.read()
        self.i = 0  # read cursor into self.data
        self.n = len(self.data)

    def int(self) -> int:
        """Return the next whitespace-separated unsigned decimal integer.

        Any byte <= 32 (space, tab, CR, LF, ...) counts as a separator.
        Signs are not supported, and 0 is returned once input is exhausted.
        """
        data = self.data
        n = self.n
        i = self.i
        while i < n and data[i] <= 32:
            i += 1
        num = 0
        while i < n and data[i] > 32:
            num = num * 10 + (data[i] - 48)
            i += 1
        self.i = i
        return num


def solve():
    """Process colour toggles and Steiner-size queries on a tree (stdin -> stdout).

    Input, in the order it is read: N; N-1 edges ``a b``; N colour flags
    for vertices 1..N; Q; then Q queries:

    * ``1 v``   -- toggle the colour flag of v.
    * ``2 x y`` -- print the number of vertices of the minimal connected
      subgraph spanning the coloured vertices in y's subtree when the tree
      is rooted at x (for x == y this is every coloured vertex).

    Internally the tree is rooted at 1. For coloured vertices v1..vk taken
    in DFS entry order, dist(v1,v2) + ... + dist(vk,v1) equals twice the
    number of edges of their spanning subgraph. A segment tree over entry
    order keeps this sum per range, so each query needs O(log N)
    LCA-based distance evaluations.
    """
    fs = FastScanner()
    N = fs.int()
    g = [[] for _ in range(N + 1)]
    for _ in range(N - 1):
        a = fs.int()
        b = fs.int()
        g[a].append(b)
        g[b].append(a)
    color = [0] * (N + 1)
    for i in range(1, N + 1):
        color[i] = fs.int()

    # -----------------------------
    # Iterative DFS from root 1, computing:
    #   tin/tout       subtree(v) == Euler-in positions [tin[v], tout[v]]
    #   euler_in       euler_in[tin[v]] == v
    #   depth, parent
    #   euler2/dep2    full Euler tour (with depths) for RMQ-LCA;
    #                  first[v] is v's first index in it
    #   child_tins[v]  tin of every child of v, in increasing order
    # -----------------------------
    depth = [0] * (N + 1)
    parent = [0] * (N + 1)
    tin = [0] * (N + 1)
    tout = [0] * (N + 1)
    euler_in = [0] * N
    first = [-1] * (N + 1)
    child_tins = [[] for _ in range(N + 1)]
    euler2 = array('I')
    dep2 = array('I')
    timer = 0
    it = [0] * (N + 1)  # it[v]: next index into g[v] to examine
    stack = [1]
    while stack:
        v = stack[-1]
        if it[v] == 0:
            # First visit. it[v] can only still be 0 later for a vertex
            # with no neighbours, and such a vertex is popped in this
            # same iteration, so this branch runs once per vertex.
            tin[v] = timer
            euler_in[timer] = v
            timer += 1
            # Siblings are entered one at a time, each finished before the
            # next starts, so these tins arrive already sorted. The root
            # records itself under the unused index 0.
            child_tins[parent[v]].append(tin[v])
            first[v] = len(euler2)
            euler2.append(v)
            dep2.append(depth[v])
        if it[v] < len(g[v]):
            to = g[v][it[v]]
            it[v] += 1
            if to == parent[v]:
                continue
            parent[to] = v
            depth[to] = depth[v] + 1
            stack.append(to)
        else:
            tout[v] = timer - 1
            stack.pop()
            if stack:
                # Back at the parent: it reappears in the Euler tour.
                p = stack[-1]
                euler2.append(p)
                dep2.append(depth[p])

    M = len(euler2)  # 2N - 1

    # -----------------------------
    # Sparse table for RMQ over dep2. Entries are indices into euler2:
    # st[j][i] is the argmin of dep2 over [i, i + 2**j).
    # -----------------------------
    logs = [0] * (M + 1)
    for i in range(2, M + 1):
        logs[i] = logs[i >> 1] + 1
    st = [array('I', range(M))]
    j = 1
    while (1 << j) <= M:
        prev = st[j - 1]
        span = 1 << (j - 1)
        new_len = M - (1 << j) + 1
        cur = array('I', [0]) * new_len  # argmin depth
        for i in range(new_len):
            i1 = prev[i]
            i2 = prev[i + span]
            cur[i] = i1 if dep2[i1] <= dep2[i2] else i2
        st.append(cur)
        j += 1

    def lca(a: int, b: int) -> int:
        """Return the lowest common ancestor of a and b (O(1) sparse-table query)."""
        ia = first[a]
        ib = first[b]
        if ia > ib:
            ia, ib = ib, ia
        k = logs[ib - ia + 1]
        t = st[k]
        i1 = t[ia]
        i2 = t[ib - (1 << k) + 1]
        return euler2[i1] if dep2[i1] <= dep2[i2] else euler2[i2]

    def dist(a: int, b: int) -> int:
        """Return the number of edges on the tree path between a and b."""
        c = lca(a, b)
        return depth[a] + depth[b] - 2 * depth[c]

    def is_ancestor(a: int, b: int) -> bool:
        """Return True if a is b or an ancestor of b (tree rooted at 1)."""
        return tin[a] <= tin[b] and tout[b] <= tout[a]

    def child_toward(y: int, x: int) -> int:
        """Return the child of y whose subtree contains x.

        Requires x to be a proper descendant of y. child_tins[y] is sorted,
        and x lies in the last child subtree entered no later than x.
        """
        tins = child_tins[y]
        return euler_in[tins[bisect_right(tins, tin[x]) - 1]]

    # -----------------------------
    # Segment tree over Euler-in positions [0..N).
    # Each node describes the coloured vertices of its range with four
    # parallel arrays (no per-node tuples):
    #   cnt  number of coloured vertices
    #   fv   first coloured vertex in tin order
    #   lv   last coloured vertex in tin order
    #   sd   sum of dist() between consecutive coloured vertices
    # 0 is the "empty vertex" sentinel for fv/lv.
    # The merge is written out inline in build/update/query on purpose,
    # to avoid an extra Python call per node.
    # -----------------------------
    size = 1
    while size < N:
        size <<= 1
    segN = 2 * size
    cnt = array('I', [0]) * segN
    fv = array('I', [0]) * segN
    lv = array('I', [0]) * segN
    sd = array('Q', [0]) * segN  # 64-bit sums

    # Leaves.
    base = size
    for i in range(N):
        v = euler_in[i]
        if color[v]:
            idx = base + i
            cnt[idx] = 1
            fv[idx] = v
            lv[idx] = v
            sd[idx] = 0

    # Internal nodes, bottom-up.
    for i in range(size - 1, 0, -1):
        L = i << 1
        R = L | 1
        cL = cnt[L]
        cR = cnt[R]
        if cL == 0:
            cnt[i] = cR
            fv[i] = fv[R]
            lv[i] = lv[R]
            sd[i] = sd[R]
        elif cR == 0:
            cnt[i] = cL
            fv[i] = fv[L]
            lv[i] = lv[L]
            sd[i] = sd[L]
        else:
            cnt[i] = cL + cR
            fv[i] = fv[L]
            lv[i] = lv[R]
            # Concatenation adds the edge between L's last and R's first.
            sd[i] = sd[L] + sd[R] + dist(lv[L], fv[R])

    def update_pos(pos: int, v: int, on: int):
        """Set Euler-in position pos (vertex v) to coloured iff on is truthy."""
        i = base + pos
        if on:
            cnt[i] = 1
            fv[i] = v
            lv[i] = v
            sd[i] = 0
        else:
            cnt[i] = 0
            fv[i] = 0
            lv[i] = 0
            sd[i] = 0
        i >>= 1
        while i:
            L = i << 1
            R = L | 1
            cL = cnt[L]
            cR = cnt[R]
            if cL == 0:
                cnt[i] = cR
                fv[i] = fv[R]
                lv[i] = lv[R]
                sd[i] = sd[R]
            elif cR == 0:
                cnt[i] = cL
                fv[i] = fv[L]
                lv[i] = lv[L]
                sd[i] = sd[L]
            else:
                cnt[i] = cL + cR
                fv[i] = fv[L]
                lv[i] = lv[R]
                sd[i] = sd[L] + sd[R] + dist(lv[L], fv[R])
            i >>= 1

    def query(l: int, r: int):
        """Summarise positions [l, r) as (count, first, last, sum_dist).

        Left and right partial results are accumulated separately so that
        concatenation always happens in tin order.
        """
        cL = 0
        fL = 0
        tL = 0
        sL = 0
        cR = 0
        fR = 0
        tR = 0
        sR = 0
        l += base
        r += base
        while l < r:
            if l & 1:
                c = cnt[l]
                if c:
                    if cL == 0:
                        cL = c
                        fL = fv[l]
                        tL = lv[l]
                        sL = sd[l]
                    else:
                        # Append node l to the right of the left accumulator.
                        sL = sL + sd[l] + dist(tL, fv[l])
                        tL = lv[l]
                        cL += c
                l += 1
            if r & 1:
                r -= 1
                c = cnt[r]
                if c:
                    if cR == 0:
                        cR = c
                        fR = fv[r]
                        tR = lv[r]
                        sR = sd[r]
                    else:
                        # Prepend node r to the left of the right accumulator.
                        sR = sd[r] + sR + dist(lv[r], fR)
                        fR = fv[r]
                        cR += c
            l >>= 1
            r >>= 1
        if cL == 0:
            return cR, fR, tR, sR
        if cR == 0:
            return cL, fL, tL, sL
        return cL + cR, fL, tR, sL + sR + dist(tL, fR)

    def steiner_vertices(c: int, f: int, t: int, s: int) -> int:
        """Return the vertex count of the subgraph spanning c coloured vertices.

        Closing the tin-ordered chain (s plus the last->first hop) walks
        every spanning edge exactly twice.
        """
        if c == 0:
            return 0
        if c == 1:
            return 1
        cycle = s + dist(t, f)
        return (cycle // 2) + 1

    def query_subtree(u: int) -> int:
        """Return the Steiner size of the coloured vertices in u's subtree (root 1)."""
        l = tin[u]
        r = tout[u] + 1
        c, f, t, s = query(l, r)
        return steiner_vertices(c, f, t, s)

    def query_complement_subtree(u: int) -> int:
        """Return the Steiner size of the coloured vertices outside u's subtree."""
        l = tin[u]
        r = tout[u] + 1
        c1, f1, t1, s1 = query(0, l)
        c2, f2, t2, s2 = query(r, N)
        if c1 == 0:
            return steiner_vertices(c2, f2, t2, s2)
        if c2 == 0:
            return steiner_vertices(c1, f1, t1, s1)
        # Prefix then suffix is still in tin order; join them with one hop.
        c = c1 + c2
        f = f1
        t = t2
        s = s1 + s2 + dist(t1, f2)
        return steiner_vertices(c, f, t, s)

    # -----------------------------
    # Process queries
    # -----------------------------
    Q = fs.int()
    out = []
    append = out.append
    for _ in range(Q):
        t = fs.int()
        if t == 1:
            v = fs.int()
            color[v] ^= 1
            update_pos(tin[v], v, color[v])
        else:
            x = fs.int()
            y = fs.int()
            if x == y:
                c, f, t2_, s = query(0, N)
                append(str(steiner_vertices(c, f, t2_, s)))
            elif is_ancestor(y, x):
                # Rooted at x, y's subtree is everything except the
                # branch of y that leads towards x.
                append(str(query_complement_subtree(child_toward(y, x))))
            else:
                append(str(query_subtree(y)))
    if out:
        sys.stdout.write("\n".join(out) + "\n")


if __name__ == "__main__":
    solve()