結果
| 問題 | No.3442 Good Vertex Connectivity |
| コンテスト | |
| ユーザー | 👑 potato167 |
| 提出日時 | 2026-02-03 16:48:40 |
| 言語 | PyPy3 (7.3.17) |
| 結果 | AC |
| 実行時間 | 2,421 ms / 3,000 ms |
| コード長 | 9,243 bytes |
| 記録 | |
| コンパイル時間 | 626 ms |
| コンパイル使用メモリ | 82,332 KB |
| 実行使用メモリ | 208,608 KB |
| 最終ジャッジ日時 | 2026-02-06 20:56:53 |
| 合計ジャッジ時間 | 123,738 ms |
| ジャッジサーバーID (参考情報) | judge3 / judge5 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 1 |
| other | AC * 69 |
ソースコード
import sys
from array import array
# ------------------------------------------------------------
# Fast scanner (bytes -> int)
# ------------------------------------------------------------
class FastScanner:
    """Whitespace-separated integer reader over the whole of stdin.

    The entire input is read once, as bytes, in ``__init__``; ``int()`` then
    walks that buffer directly without creating intermediate str/bytes
    tokens (the point of this class is speed under PyPy).
    """

    __slots__ = ("data", "i", "n")

    def __init__(self):
        self.data = sys.stdin.buffer.read()  # whole input as bytes
        self.i = 0  # current read offset into self.data
        self.n = len(self.data)

    def int(self) -> int:
        """Parse and return the next decimal integer, optionally '-'-signed.

        Any byte <= 32 (space, tab, CR, LF, ...) is treated as a separator.
        At end of input this returns 0 instead of raising.
        """
        data = self.data
        n = self.n
        i = self.i
        # Skip separators.
        while i < n and data[i] <= 32:
            i += 1
        # Optional minus sign (byte 45 == b"-"); previously it was folded
        # into the digit accumulation and produced a wrong value.
        negative = i < n and data[i] == 45
        if negative:
            i += 1
        num = 0
        while i < n and data[i] > 32:
            num = num * 10 + (data[i] - 48)  # 48 == ord("0")
            i += 1
        self.i = i
        return -num if negative else num
def solve():
    """Read the tree, vertex colours and queries from stdin and answer them.

    Input layout, as consumed below: N, then N-1 edges ``a b``, then N
    colour values, then Q, then Q queries.
      * ``1 v``   toggles the colour of v (colours are toggled with ``^= 1``,
                  so they are presumably 0/1 -- confirm against the statement).
      * otherwise ``t x y``: consider the subtree of y when the tree is
        re-rooted at x (the whole tree if x == y). Print the number of
        vertices of the minimal connected subgraph containing every coloured
        vertex of that subtree (0 if none).

    Approach: root the tree at 1. For a vertex set listed in DFS pre-order
    v1..vk, the closed walk v1 -> v2 -> ... -> vk -> v1 has length equal to
    twice the edge count of the set's Steiner tree, so the vertex count is
    cycle // 2 + 1. A segment tree over pre-order positions maintains, for
    coloured vertices, the sum of distances between consecutive ones.
    Distances use LCA via an Euler tour plus a sparse table.
    """
    fs = FastScanner()
    N = fs.int()
    g = [[] for _ in range(N + 1)]
    for _ in range(N - 1):
        a = fs.int()
        b = fs.int()
        g[a].append(b)
        g[b].append(a)
    color = [0] * (N + 1)
    for i in range(1, N + 1):
        color[i] = fs.int()

    # -----------------------------
    # Root the tree at 1 with an iterative DFS:
    #   tin[v]/tout[v] : pre-order index of v / last pre-order index in its subtree
    #   euler_in[k]    : vertex with pre-order index k
    #   depth, parent  : parent[1] = 0 is the "no parent" sentinel
    #   euler2/dep2    : full Euler tour (length 2N-1) and depths, for RMQ-LCA
    #   first[v]       : first position of v in euler2
    # -----------------------------
    depth = [0] * (N + 1)
    parent = [0] * (N + 1)
    tin = [0] * (N + 1)
    tout = [0] * (N + 1)
    euler_in = [0] * N
    first = [-1] * (N + 1)
    euler2 = array('I')
    dep2 = array('I')
    timer = 0
    it = [0] * (N + 1)  # next adjacency index to explore, per vertex
    stack = [1]
    parent[1] = 0
    depth[1] = 0
    while stack:
        v = stack[-1]
        if it[v] == 0:
            # v is on top of the stack for the first time: give it a
            # pre-order slot and record its first Euler-tour occurrence.
            tin[v] = timer
            euler_in[timer] = v
            timer += 1
            if first[v] == -1:
                first[v] = len(euler2)
                euler2.append(v)
                dep2.append(depth[v])
        if it[v] < len(g[v]):
            to = g[v][it[v]]
            it[v] += 1
            if to == parent[v]:
                continue
            parent[to] = v
            depth[to] = depth[v] + 1
            stack.append(to)
        else:
            # All children done: close v's subtree interval.
            tout[v] = timer - 1
            stack.pop()
            if stack:
                # Returning to the parent records it again in the Euler tour.
                p = stack[-1]
                euler2.append(p)
                dep2.append(depth[p])
    M = len(euler2)  # ~ 2N-1

    # -----------------------------
    # Sparse table for range-argmin over dep2. Entries are indices into
    # euler2: st[j][i] is the index of the shallowest entry in
    # euler2[i : i + 2**j].
    # -----------------------------
    logs = [0] * (M + 1)  # logs[x] == floor(log2(x)) for x >= 1
    for i in range(2, M + 1):
        logs[i] = logs[i >> 1] + 1
    K = logs[M] + 1
    st = [None] * K
    st0 = array('I', range(M))
    st[0] = st0
    j = 1
    while (1 << j) <= M:
        prev = st[j - 1]
        span = 1 << (j - 1)
        new_len = M - (1 << j) + 1
        cur = array('I', [0]) * new_len
        for i in range(new_len):
            i1 = prev[i]
            i2 = prev[i + span]
            cur[i] = i1 if dep2[i1] <= dep2[i2] else i2
        st[j] = cur
        j += 1

    # Local bindings for speed inside the hot closures below.
    _first = first
    _logs = logs
    _st = st
    _euler2 = euler2
    _dep2 = dep2
    _depth = depth

    def lca(a: int, b: int) -> int:
        """Lowest common ancestor of a and b (tree rooted at 1), via O(1) RMQ."""
        ia = _first[a]
        ib = _first[b]
        if ia > ib:
            ia, ib = ib, ia
        k = _logs[ib - ia + 1]
        t = _st[k]
        # Two overlapping power-of-two windows cover [ia, ib].
        i1 = t[ia]
        i2 = t[ib - (1 << k) + 1]
        return _euler2[i1] if _dep2[i1] <= _dep2[i2] else _euler2[i2]

    def dist(a: int, b: int) -> int:
        """Number of edges on the path between a and b."""
        c = lca(a, b)
        return _depth[a] + _depth[b] - 2 * _depth[c]

    def is_ancestor(a: int, b: int) -> bool:
        """True if a is an ancestor of b (rooted at 1); also True when a == b."""
        return tin[a] <= tin[b] and tout[b] <= tout[a]

    # -----------------------------
    # Binary lifting, used only by jump_up. Stored as array('I') rows to
    # reduce per-element overhead; vertex 0 maps to itself (sentinel).
    # -----------------------------
    LOG = (N).bit_length()
    up = [None] * LOG
    up0 = array('I', parent)  # parent[0..N]
    up[0] = up0
    for k in range(1, LOG):
        prev = up[k - 1]
        cur = array('I', [0]) * (N + 1)
        for v in range(1, N + 1):
            cur[v] = prev[prev[v]]
        up[k] = cur

    def jump_up(v: int, k: int) -> int:
        """Return the ancestor of v exactly k levels above it (needs k < 2**LOG)."""
        bit = 0
        while k:
            if k & 1:
                v = up[bit][v]
            k >>= 1
            bit += 1
        return v

    # -----------------------------
    # Segment tree over pre-order positions [0, N). Each node describes the
    # coloured vertices in its range, kept in four parallel arrays to avoid
    # tuple allocations:
    #   cnt : number of coloured vertices
    #   fv  : first coloured vertex (pre-order), 0 if none
    #   lv  : last coloured vertex (pre-order), 0 if none
    #   sd  : sum of distances between consecutive coloured vertices
    # -----------------------------
    size = 1
    while size < N:
        size <<= 1
    segN = 2 * size
    cnt = array('I', [0]) * segN
    fv = array('I', [0]) * segN
    lv = array('I', [0]) * segN
    sd = array('Q', [0]) * segN  # 64-bit sums
    base = size
    for i in range(N):
        v = euler_in[i]
        if color[v]:
            idx = base + i
            cnt[idx] = 1
            fv[idx] = v
            lv[idx] = v
            sd[idx] = 0

    def pull(i: int) -> None:
        """Recompute internal node i from its children (left range, then right).

        Shared by the initial build and by update_pos so the merge rule
        lives in exactly one place.
        """
        L = i << 1
        R = L | 1
        cL = cnt[L]
        cR = cnt[R]
        if cL == 0:
            cnt[i] = cR
            fv[i] = fv[R]
            lv[i] = lv[R]
            sd[i] = sd[R]
        elif cR == 0:
            cnt[i] = cL
            fv[i] = fv[L]
            lv[i] = lv[L]
            sd[i] = sd[L]
        else:
            cnt[i] = cL + cR
            fv[i] = fv[L]
            lv[i] = lv[R]
            # The two runs are joined by the path from the left run's last
            # vertex to the right run's first vertex.
            sd[i] = sd[L] + sd[R] + dist(lv[L], fv[R])

    for i in range(size - 1, 0, -1):
        pull(i)

    def update_pos(pos: int, v: int, on: int):
        """Set pre-order slot pos to hold vertex v (on) or nothing (off), then repair ancestors."""
        i = base + pos
        if on:
            cnt[i] = 1
            fv[i] = v
            lv[i] = v
        else:
            cnt[i] = 0
            fv[i] = 0
            lv[i] = 0
        sd[i] = 0
        i >>= 1
        while i:
            pull(i)
            i >>= 1

    def query(l: int, r: int):
        """Aggregate coloured vertices at pre-order positions [l, r).

        Returns (count, first vertex, last vertex, consecutive-distance sum),
        or (0, 0, 0, 0) if there are none. The merge is order-dependent, so
        the left and right sides of the bottom-up walk are accumulated
        separately and joined at the end.
        """
        cL = 0
        fL = 0
        tL = 0
        sL = 0
        cR = 0
        fR = 0
        tR = 0
        sR = 0
        l += base
        r += base
        while l < r:
            if l & 1:
                c = cnt[l]
                if c:
                    if cL == 0:
                        cL = c
                        fL = fv[l]
                        tL = lv[l]
                        sL = sd[l]
                    else:
                        # Append node l after the accumulated left run.
                        sL = sL + sd[l] + dist(tL, fv[l])
                        tL = lv[l]
                        cL += c
                l += 1
            if r & 1:
                r -= 1
                c = cnt[r]
                if c:
                    if cR == 0:
                        cR = c
                        fR = fv[r]
                        tR = lv[r]
                        sR = sd[r]
                    else:
                        # Prepend node r before the accumulated right run.
                        sR = sd[r] + sR + dist(lv[r], fR)
                        fR = fv[r]
                        cR += c
            l >>= 1
            r >>= 1
        if cL == 0:
            return cR, fR, tR, sR
        if cR == 0:
            return cL, fL, tL, sL
        return cL + cR, fL, tR, sL + sR + dist(tL, fR)

    def steiner_vertices(c: int, f: int, t: int, s: int) -> int:
        """Vertex count of the minimal subtree spanning an aggregated vertex set.

        Closing the walk (last -> first) gives twice the Steiner-tree edge count.
        """
        if c == 0:
            return 0
        if c == 1:
            return 1
        cycle = s + dist(t, f)
        return (cycle // 2) + 1

    def query_subtree(u: int) -> int:
        """Answer for the coloured vertices in u's subtree (rooted at 1)."""
        l = tin[u]
        r = tout[u] + 1
        c, f, t, s = query(l, r)
        return steiner_vertices(c, f, t, s)

    def query_complement_subtree(u: int) -> int:
        """Answer for the coloured vertices outside u's subtree (rooted at 1).

        These are pre-order positions [0, tin[u]) followed by (tout[u], N),
        which together are still in pre-order, so the two parts concatenate.
        """
        l = tin[u]
        r = tout[u] + 1
        c1, f1, t1, s1 = query(0, l)
        c2, f2, t2, s2 = query(r, N)
        if c1 == 0:
            return steiner_vertices(c2, f2, t2, s2)
        if c2 == 0:
            return steiner_vertices(c1, f1, t1, s1)
        c = c1 + c2
        f = f1
        t = t2
        s = s1 + s2 + dist(t1, f2)
        return steiner_vertices(c, f, t, s)

    # -----------------------------
    # Process queries
    # -----------------------------
    Q = fs.int()
    out = []
    append = out.append
    for _ in range(Q):
        t = fs.int()
        if t == 1:
            v = fs.int()
            color[v] ^= 1
            update_pos(tin[v], v, color[v])
        else:
            x = fs.int()
            y = fs.int()
            if x == y:
                # Rooted at y itself: the "subtree" is the whole tree.
                c, f, t2_, s = query(0, N)
                append(str(steiner_vertices(c, f, t2_, s)))
            elif is_ancestor(y, x):
                # Re-rooting at x (below y) turns y's subtree into everything
                # except the subtree of y's child z on the path to x.
                z = jump_up(x, depth[x] - depth[y] - 1)
                append(str(query_complement_subtree(z)))
            else:
                # x is outside y's subtree, so re-rooting does not change it.
                append(str(query_subtree(y)))
    if out:
        # Terminate the last line as well; previously the output had no
        # trailing newline.
        sys.stdout.write("\n".join(out) + "\n")
# Script entry point: solve() consumes all of stdin, so run it only when this
# file is executed directly rather than imported.
if __name__ == "__main__":
    solve()
potato167