結果
問題 |
No.899 γatheree
|
ユーザー |
|
提出日時 | 2025-01-08 01:34:58 |
言語 | PyPy3 (7.3.15) |
結果 |
WA
|
実行時間 | - |
コード長 | 8,376 bytes |
コンパイル時間 | 2,223 ms |
コンパイル使用メモリ | 82,120 KB |
実行使用メモリ | 126,696 KB |
最終ジャッジ日時 | 2025-01-08 01:35:41 |
合計ジャッジ時間 | 42,786 ms |
ジャッジサーバーID (参考情報) |
judge5 / judge3 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | AC * 1 |
other | AC * 6 WA * 1 TLE * 16 |
ソースコード
import sys
from collections import deque
from typing import Iterable, List, Sequence, Tuple


class LazyPropSegTree:
    """Lazy-propagation segment tree over a caller-supplied monoid.

    The tree stores ``n`` values and supports point reads/writes, range
    products under ``op`` and range application of an "action" ``f``.

    Args:
        op: Associative binary function combining two stored values.
        e: Identity element of ``op``.
        mapping: ``mapping(f, x)`` returns the value of a whole node ``x``
            after action ``f`` has been applied to every element below it.
        composition: ``composition(f, g)`` returns the single action
            equivalent to applying ``g`` first and ``f`` afterwards.
        id_: Identity action (``mapping(id_, x) == x``).
        v: Initial values; only read, never mutated.
    """

    def __init__(self, op, e, mapping, composition, id_, v=()):
        values = list(v)
        self.n = len(values)
        self.log = (self.n - 1).bit_length()
        self.size = 1 << self.log
        # d[1] is the root, leaves live at d[size : size + n].
        self.d = [e] * (2 * self.size)
        # Pending actions exist for internal nodes only.
        self.lz = [id_] * self.size
        self.op = op
        self.e = e
        self.mapping = mapping
        self.composition = composition
        self.id_ = id_
        self.d[self.size:self.size + self.n] = values
        for i in range(self.size - 1, 0, -1):
            self.update(i)

    def update(self, k):
        """Recompute internal node ``k`` from its two children."""
        self.d[k] = self.op(self.d[2 * k], self.d[2 * k + 1])

    def all_apply(self, k, f):
        """Apply action ``f`` to node ``k`` and queue it for its subtree."""
        self.d[k] = self.mapping(f, self.d[k])
        if k < self.size:
            self.lz[k] = self.composition(f, self.lz[k])

    def push(self, k):
        """Hand the pending action of node ``k`` down to both children."""
        pending = self.lz[k]
        self.all_apply(2 * k, pending)
        self.all_apply(2 * k + 1, pending)
        self.lz[k] = self.id_

    def __setitem__(self, p, x):
        """Overwrite the element at position ``p`` with ``x``."""
        assert 0 <= p < self.n
        p += self.size
        # Flush every ancestor first so no stale action overwrites ``x`` later.
        for i in range(self.log, 0, -1):
            self.push(p >> i)
        self.d[p] = x
        for i in range(1, self.log + 1):
            self.update(p >> i)

    def __getitem__(self, p):
        """Return the current element at position ``p``."""
        assert 0 <= p < self.n
        p += self.size
        for i in range(self.log, 0, -1):
            self.push(p >> i)
        return self.d[p]

    def prod(self, left, right):
        """Return ``op`` folded over the half-open range ``[left, right)``."""
        assert 0 <= left <= right <= self.n
        if left == right:
            return self.e
        left += self.size
        right += self.size
        # Only ancestors that straddle a range boundary need flushing.
        for i in range(self.log, 0, -1):
            if ((left >> i) << i) != left:
                self.push(left >> i)
            if ((right >> i) << i) != right:
                self.push(right >> i)
        sml, smr = self.e, self.e
        while left < right:
            if left & 1:
                sml = self.op(sml, self.d[left])
                left += 1
            if right & 1:
                right -= 1
                smr = self.op(self.d[right], smr)
            left >>= 1
            right >>= 1
        return self.op(sml, smr)

    def all_prod(self):
        """Return ``op`` folded over every element."""
        return self.d[1]

    def apply(self, p, f):
        """Apply action ``f`` to the single element at position ``p``."""
        assert 0 <= p < self.n
        p += self.size
        for i in range(self.log, 0, -1):
            self.push(p >> i)
        self.d[p] = self.mapping(f, self.d[p])
        for i in range(1, self.log + 1):
            self.update(p >> i)

    def apply_lr(self, left, right, f):
        """Apply action ``f`` to every element of ``[left, right)``."""
        assert 0 <= left <= right <= self.n
        if left == right:
            return
        left += self.size
        right += self.size
        for i in range(self.log, 0, -1):
            if ((left >> i) << i) != left:
                self.push(left >> i)
            if ((right >> i) << i) != right:
                self.push((right - 1) >> i)
        lo, hi = left, right
        while lo < hi:
            if lo & 1:
                self.all_apply(lo, f)
                lo += 1
            if hi & 1:
                hi -= 1
                self.all_apply(hi, f)
            lo >>= 1
            hi >>= 1
        # Rebuild the boundary ancestors that were flushed above.
        for i in range(1, self.log + 1):
            if ((left >> i) << i) != left:
                self.update(left >> i)
            if ((right >> i) << i) != right:
                self.update((right - 1) >> i)

    def max_right(self, left, g):
        """Return the largest ``r`` such that ``g(prod(left, r))`` holds.

        ``g`` must accept the identity element and is expected to be
        monotone along the range.
        """
        assert 0 <= left <= self.n
        assert g(self.e)
        if left == self.n:
            return self.n
        left += self.size
        for i in range(self.log, 0, -1):
            self.push(left >> i)
        sm = self.e
        while True:
            while left % 2 == 0:
                left >>= 1
            if not g(self.op(sm, self.d[left])):
                # The answer is inside this node: descend to the exact leaf.
                while left < self.size:
                    self.push(left)
                    left <<= 1
                    if g(self.op(sm, self.d[left])):
                        sm = self.op(sm, self.d[left])
                        left += 1
                return left - self.size
            sm = self.op(sm, self.d[left])
            left += 1
            if (left & -left) == left:
                break
        return self.n

    def min_left(self, right, g):
        """Return the smallest ``l`` such that ``g(prod(l, right))`` holds.

        ``g`` must accept the identity element and is expected to be
        monotone along the range.
        """
        assert 0 <= right <= self.n
        assert g(self.e)
        if right == 0:
            return 0
        right += self.size
        for i in range(self.log, 0, -1):
            self.push((right - 1) >> i)
        sm = self.e
        while True:
            right -= 1
            while right > 1 and right % 2:
                right >>= 1
            if not g(self.op(self.d[right], sm)):
                # The answer is inside this node: descend to the exact leaf.
                while right < self.size:
                    self.push(right)
                    right = 2 * right + 1
                    if g(self.op(self.d[right], sm)):
                        sm = self.op(self.d[right], sm)
                        right -= 1
                return right + 1 - self.size
            sm = self.op(self.d[right], sm)
            if (right & -right) == right:
                break
        return 0


def _add(x, y):
    """Monoid operation: plain integer sum."""
    return x + y


def _clear_if(flag, x):
    """Action: a set flag zeroes the whole node, an unset flag keeps it.

    Assigning zero needs no segment length, so node values stay plain
    integers instead of packed (sum, count) pairs.
    """
    return 0 if flag else x


def _merge_flags(new, old):
    """Compose two "clear" flags: once cleared, a node stays cleared."""
    return new or old


def _bfs_layout(adjacency, root=0):
    """Number the vertices in BFS order starting from ``root``.

    In BFS order the children of a vertex occupy consecutive positions, and
    so do its grandchildren, which lets both sets be handled as one range.

    Returns:
        ``(parent, position, child_span, grandchild_span)`` where
        ``parent[v]`` is -1 for the root, ``position[v]`` is the BFS index
        of ``v``, and each span is an inclusive ``[lo, hi]`` pair of BFS
        indices (``lo > hi`` means the set is empty).
    """
    n = len(adjacency)
    parent = [-1] * n
    position = [-1] * n
    child_span = [[n + 1, -1] for _ in range(n)]
    grandchild_span = [[n + 1, -1] for _ in range(n)]

    pending = deque([root])
    next_position = 0
    while pending:
        node = pending.popleft()
        here = next_position
        position[node] = here
        next_position += 1

        par = parent[node]
        if par >= 0:
            span = child_span[par]
            if here < span[0]:
                span[0] = here
            if here > span[1]:
                span[1] = here
            grand = parent[par]
            if grand >= 0:
                span = grandchild_span[grand]
                if here < span[0]:
                    span[0] = here
                if here > span[1]:
                    span[1] = here

        for neighbour in adjacency[node]:
            if neighbour == par:
                continue
            parent[neighbour] = node
            pending.append(neighbour)
    return parent, position, child_span, grandchild_span


def solve(
    n: int,
    edges: Iterable[Tuple[int, int]],
    values: Sequence[int],
    queries: Iterable[int],
) -> List[int]:
    """Answer every gather query on a tree rooted at vertex 0.

    A query on vertex ``x`` moves onto ``x`` the values currently held by
    ``x`` itself, its parent, its grandparent, its parent's children, its
    children and its grandchildren, and reports the resulting value of ``x``.

    Args:
        n: Number of vertices, labelled ``0 .. n-1``.
        edges: The ``n - 1`` undirected edges of the tree.
        values: Initial value of each vertex.
        queries: Queried vertices, processed in order.

    Returns:
        The value reported for each query, in query order.
    """
    adjacency = [[] for _ in range(n)]
    for u, v in edges:
        adjacency[u].append(v)
        adjacency[v].append(u)

    parent, position, child_span, grandchild_span = _bfs_layout(adjacency)

    initial = [0] * n
    for pos, value in zip(position, values):
        initial[pos] = value
    seg = LazyPropSegTree(
        op=_add,
        e=0,
        mapping=_clear_if,
        composition=_merge_flags,
        id_=False,
        v=initial,
    )

    def take_vertex(vertex):
        """Remove and return the value stored on a single vertex."""
        pos = position[vertex]
        taken = seg[pos]
        seg[pos] = 0
        return taken

    def take_span(span):
        """Remove and return the total stored on an inclusive BFS span."""
        lo, hi = span
        if lo > hi:
            return 0
        taken = seg.prod(lo, hi + 1)
        seg.apply_lr(lo, hi + 1, True)
        return taken

    answers = []
    for x in queries:
        par = parent[x]
        if par >= 0:
            total = take_vertex(par)
            # The parent's children include x itself, so x's own value is
            # collected here.
            total += take_span(child_span[par])
            grand = parent[par]
            if grand >= 0:
                total += take_vertex(grand)
        else:
            # The root belongs to no sibling range: its own value has to be
            # collected explicitly or it would be lost on the write below.
            total = seg[position[x]]
        total += take_span(child_span[x])
        total += take_span(grandchild_span[x])
        seg[position[x]] = total
        answers.append(total)
    return answers


def main() -> None:
    """Read the tree, values and queries from stdin; print one answer per line."""
    tokens = iter(sys.stdin.buffer.read().split())
    n = int(next(tokens))
    edges = [(int(next(tokens)), int(next(tokens))) for _ in range(n - 1)]
    values = [int(next(tokens)) for _ in range(n)]
    query_count = int(next(tokens))
    queries = [int(next(tokens)) for _ in range(query_count)]
    answers = solve(n, edges, values, queries)
    sys.stdout.write("".join(f"{answer}\n" for answer in answers))


if __name__ == "__main__":
    main()