結果

問題 No.899 γatheree
ユーザー Shinya Fujita
提出日時 2025-01-09 01:18:08
言語 PyPy3 (7.3.15)
結果 TLE
実行時間 -
コード長 8,445 bytes
コンパイル時間 1,023 ms
コンパイル使用メモリ 81,792 KB
実行使用メモリ 126,812 KB
最終ジャッジ日時 2025-01-09 01:19:02
合計ジャッジ時間 42,150 ms
ジャッジサーバーID (参考情報) judge1 / judge4
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 1
other AC * 5 TLE * 18
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
from collections import deque

DEBUG = False

# sys.stdin.readline avoids input()'s per-call overhead on the ~N input lines.
_readline = sys.stdin.readline

N = int(_readline())
tree = [[] for _ in range(N)]
for _ in range(N - 1):
    u, v = map(int, _readline().split())
    tree[u].append(v)
    tree[v].append(u)

A = list(map(int, _readline().split()))

# BFS from node 0.  Positions are handed out in FIFO pop order, so all children
# of a node are popped consecutively (they were enqueued together), and so are
# all of its grandchildren.  Each of those groups is therefore one contiguous
# position interval that the segment tree can address with a single range op.
#   Idx[v]       : BFS position of node v
#   parent[v]    : parent of v (-1 for the root, node 0)
#   children[v]  : children of v
#   child_lr[v]  : [first, last] BFS position of v's children ([N+1, -1] if none)
#   gchild_lr[v] : [first, last] BFS position of v's grandchildren ([N+1, -1] if none)
nex = 0
que = deque([0])
parent = [-1] * N
children = [[] for _ in range(N)]
child_lr = [[N + 1, -1] for _ in range(N)]
gchild_lr = [[N + 1, -1] for _ in range(N)]
Idx = [-1] * N
while que:
    node = que.popleft()
    par = parent[node]
    pos = nex
    Idx[node] = pos
    nex += 1
    if par >= 0:
        lr = child_lr[par]
        lr[0] = min(lr[0], pos)
        lr[1] = max(lr[1], pos)
        gpar = parent[par]
        if gpar >= 0:
            glr = gchild_lr[gpar]
            glr[0] = min(glr[0], pos)
            glr[1] = max(glr[1], pos)
    for nn in tree[node]:
        if nn == par:
            continue
        que.append(nn)
        children[node].append(nn)
        parent[nn] = node


e = 0  # identity of op: packed value 0 with count 0


def composition(p, q):
    """Compose lazy tags: apply *q* first, then *p*.

    ``None`` means "no pending assignment", so it leaves *q* unchanged.
    Any other tag is an assignment that overrides whatever *q* was.
    """
    return q if p is None else p


id_ = None  # identity lazy tag: no pending assignment
BASE = 1 << 20  # packing radix: element = value * BASE + count
def op(x, y):
    """Combine two packed segment values.

    Each value packs ``value * BASE + count``.  The packing is linear, so
    adding the packed integers adds the value fields and the count fields
    together.  That is exactly what splitting both operands with divmod and
    re-packing the sums produced, but without the two divisions.
    """
    return x + y

def mapping(p, x):
    """Apply lazy tag *p* to the packed value *x*.

    ``None`` means "no pending assignment" and returns *x* unchanged.
    Otherwise every element covered by *x* is assigned *p*: the value field
    becomes ``count * p`` and the count field is kept.
    """
    if p is None:
        return x
    count = x % BASE  # only the count field is needed; the old value is overwritten
    return count * p * BASE + count

class LazyPropSegTree:
    """Lazy-propagation segment tree over a monoid acted on by lazy tags.

    Parameters
    ----------
    op : callable (S, S) -> S
        Associative binary operation used to aggregate values.
    e : S
        Identity element of ``op``.  It is a value, not a factory.
    mapping : callable (F, S) -> S
        Applies a lazy tag to an aggregated value.
    composition : callable (F, F) -> F
        ``composition(f, g)`` is the tag equivalent to applying ``g`` and then
        ``f``.  Newly applied tags are passed as ``f``.
    id_ : F
        Identity tag.  It must satisfy ``mapping(id_, x) == x`` and
        ``composition(id_, g) == g``.  ``push`` relies on this to skip work
        when nothing is pending.
    v : sequence of S
        Initial values.  Its length fixes the number of elements ``n``.

    Node 1 is the root.  Node k has children 2k and 2k+1.  Leaves start at
    index ``size``.
    """

    def __init__(self, op, e, mapping, composition, id_, v=()):
        self.n = len(v)
        self.log = (self.n - 1).bit_length()
        self.size = 1 << self.log
        self.d = [e] * (2 * self.size)   # aggregated value per node
        self.lz = [id_] * self.size      # pending tag per internal node
        self.op = op
        self.e = e
        self.mapping = mapping
        self.composition = composition
        self.id_ = id_

        self.d[self.size:self.size + self.n] = v
        for i in range(self.size - 1, 0, -1):
            self.update(i)

    def update(self, k):
        """Recompute node *k* from its two children."""
        self.d[k] = self.op(self.d[2*k], self.d[2*k+1])

    def all_apply(self, k, f):
        """Apply tag *f* to node *k*, deferring it to the subtree if *k* is internal."""
        self.d[k] = self.mapping(f, self.d[k])
        if k < self.size:
            self.lz[k] = self.composition(f, self.lz[k])

    def push(self, k):
        """Propagate the pending tag of internal node *k* to its two children."""
        f = self.lz[k]
        if f is self.id_:
            # Nothing pending.  By the id_ contract both child updates would be
            # no-ops.  Every point/range operation pushes ~log n ancestors, so
            # skipping them is a large constant-factor saving.
            return
        self.all_apply(2*k, f)
        self.all_apply(2*k+1, f)
        self.lz[k] = self.id_

    def __setitem__(self, p, x):
        """Set element *p* to *x*."""
        assert (0 <= p) and (p < self.n)
        p += self.size
        for i in range(self.log, 0, -1):
            self.push(p >> i)
        self.d[p] = x
        for i in range(1, self.log + 1):
            self.update(p >> i)

    def __getitem__(self, p):
        """Return element *p* with all pending tags applied."""
        assert (0 <= p) and (p < self.n)
        p += self.size
        for i in range(self.log, 0, -1):
            self.push(p >> i)
        return self.d[p]

    def prod(self, left, right):
        """Return the ``op``-aggregate of elements [left, right); ``e`` if empty."""
        assert 0 <= left and left <= right and right <= self.n
        if left == right:
            return self.e
        left += self.size
        right += self.size
        # Flush tags on the boundary paths so the nodes read below are current.
        for i in range(self.log, 0, -1):
            if ((left >> i) << i) != left:
                self.push(left >> i)
            if ((right >> i) << i) != right:
                self.push(right >> i)
        sml, smr = self.e, self.e
        while left < right:
            if left & 1:
                sml = self.op(sml, self.d[left])
                left += 1
            if right & 1:
                right -= 1
                smr = self.op(self.d[right], smr)
            left >>= 1
            right >>= 1
        return self.op(sml, smr)

    def all_prod(self):
        """Return the aggregate of every element."""
        return self.d[1]

    def apply(self, p, f):
        """Apply tag *f* to the single element *p*."""
        assert (0 <= p) and (p < self.n)
        p += self.size
        for i in range(self.log, 0, -1):
            self.push(p >> i)
        self.d[p] = self.mapping(f, self.d[p])
        for i in range(1, self.log + 1):
            self.update(p >> i)

    def apply_lr(self, left, right, f):
        """Apply tag *f* to every element in [left, right)."""
        assert 0 <= left and left <= right and right <= self.n
        if left == right:
            return
        left += self.size
        right += self.size
        for i in range(self.log, 0, -1):
            if ((left >> i) << i) != left:
                self.push(left >> i)
            if ((right >> i) << i) != right:
                self.push((right - 1) >> i)
        l2, r2 = left, right
        while left < right:
            if left & 1:
                self.all_apply(left, f)
                left += 1
            if right & 1:
                right -= 1
                self.all_apply(right, f)
            left >>= 1
            right >>= 1
        left, right = l2, r2
        # Refresh the ancestors of the two boundaries, bottom-up.
        for i in range(1, self.log + 1):
            if ((left >> i) << i) != left:
                self.update(left >> i)
            if ((right >> i) << i) != right:
                self.update((right - 1) >> i)

    def max_right(self, left, g):
        """Return the largest r with ``g(prod(left, r))`` true.

        Requires ``g(e)`` to be true.  Assumes *g* is monotone (true on a prefix
        of r values); the descent below relies on that.
        """
        assert (0 <= left) and (left <= self.n)
        assert g(self.e)
        if left == self.n:
            return self.n
        left += self.size
        for i in range(self.log, 0, -1):
            self.push(left >> i)
        sm = self.e
        while True:
            while left % 2 == 0:
                left >>= 1
            if not g(self.op(sm, self.d[left])):
                while left < self.size:
                    self.push(left)
                    left <<= 1
                    if g(self.op(sm, self.d[left])):
                        sm = self.op(sm, self.d[left])
                        left += 1
                return left - self.size
            sm = self.op(sm, self.d[left])
            left += 1
            if (left & -left) == left:
                break
        return self.n

    def min_left(self, right, g):
        """Return the smallest l with ``g(prod(l, right))`` true.

        Requires ``g(e)`` to be true.  Assumes *g* is monotone (true on a suffix
        of l values); the descent below relies on that.
        """
        assert (0 <= right) and (right <= self.n)
        assert g(self.e)
        if right == 0:
            return 0
        right += self.size
        for i in range(self.log, 0, -1):
            self.push((right - 1) >> i)
        sm = self.e
        while True:
            right -= 1
            while (right > 1) and (right % 2):
                right >>= 1
            if not g(self.op(self.d[right], sm)):
                while right < self.size:
                    self.push(right)
                    right = 2 * right + 1
                    if g(self.op(self.d[right], sm)):
                        sm = self.op(self.d[right], sm)
                        right -= 1
                return right + 1 - self.size
            sm = self.op(self.d[right], sm)
            if (right & -right) == right:
                break
        return 0


import sys

# Segment position Idx[v] holds node v as value * BASE + count (count starts at 1).
# NOTE(review): once a value exceeds ~2**43 the packed int no longer fits in
# 64 bits, which PyPy then handles as a slower arbitrary-precision int.
# Confirm the input bounds if this is still too slow.
data = [0] * N
for i, a in zip(Idx, A):
    data[i] = a*BASE + 1


seg = LazyPropSegTree(
    op=op, e=e, mapping=mapping, composition=composition, id_=id_,
    v=data
)

# Debug traces go to stderr so they can never corrupt the judged stdout.
if DEBUG:
    print(Idx, file=sys.stderr)


def _take_point(pos):
    """Return the value at segment position *pos* and reset it to 0 (count 1)."""
    value = seg[pos] // BASE
    if DEBUG:
        print(f'seg[{pos}] = {value}', file=sys.stderr)
    seg[pos] = 1
    return value


def _take_range(lo, hi):
    """Return the summed value over positions lo..hi (inclusive) and assign 0 to them."""
    value = seg.prod(lo, hi + 1) // BASE
    if DEBUG:
        print(f'{lo = }, {hi = }, {value = }', file=sys.stderr)
    seg.apply_lr(lo, hi + 1, 0)
    return value


def _gather(x):
    """Sum every node within distance 2 of *x*, zero them, store the sum at *x*.

    The nodes are the grandparent, the parent, the siblings (which include *x*
    itself), the children and the grandchildren.  Each group is read and
    zeroed before the next, so no node is counted twice.  Returns the sum.
    """
    total = 0
    par = parent[x]
    if DEBUG:
        print('----------', file=sys.stderr)
        print(f'{x = }, {par = }', file=sys.stderr)
    if par >= 0:
        total += _take_point(Idx[par])
        fr, to = child_lr[par]  # x's siblings, x included
        total += _take_range(fr, to)
        gpar = parent[par]
        if gpar >= 0:
            total += _take_point(Idx[gpar])
    else:
        # x is the root: no sibling range covers it, so read it directly.
        # Its slot is overwritten with the total below, so no reset is needed.
        total += seg[Idx[x]] // BASE

    # child_lr / gchild_lr hold the sentinel [N+1, -1] when the group is empty.
    cfr, cto = child_lr[x]
    if cto >= 0:
        total += _take_range(cfr, cto)
    gcfr, gcto = gchild_lr[x]
    if gcto >= 0:
        total += _take_range(gcfr, gcto)

    seg[Idx[x]] = total*BASE + 1
    return total


_readline = sys.stdin.readline
Q = int(_readline())
# Buffer all answers and write them once; per-query print() is a notable cost.
answers = [_gather(int(_readline())) for _ in range(Q)]
if answers:
    sys.stdout.write('\n'.join(map(str, answers)) + '\n')