結果

問題 No.1641 Tree Xor Query
ユーザー rlangevinrlangevin
提出日時 2024-01-18 12:34:42
言語 PyPy3 (7.3.15)
結果
AC  
実行時間 302 ms / 5,000 ms
コード長 5,355 bytes
コンパイル時間 2,472 ms
コンパイル使用メモリ 81,444 KB
実行使用メモリ 111,964 KB
最終ジャッジ日時 2024-01-18 12:34:48
合計ジャッジ時間 4,205 ms
ジャッジサーバーID (参考情報) judge15 / judge11
このコードへのチャレンジ (要ログイン)

テストケース

テストケース表示
入力 結果 実行時間 実行使用メモリ
testcase_00 AC 40 ms 53,460 KB
testcase_01 AC 39 ms 53,460 KB
testcase_02 AC 40 ms 53,460 KB
testcase_03 AC 43 ms 55,608 KB
testcase_04 AC 41 ms 55,604 KB
testcase_05 AC 43 ms 55,604 KB
testcase_06 AC 43 ms 55,604 KB
testcase_07 AC 41 ms 55,604 KB
testcase_08 AC 39 ms 53,460 KB
testcase_09 AC 41 ms 55,604 KB
testcase_10 AC 43 ms 55,604 KB
testcase_11 AC 41 ms 53,460 KB
testcase_12 AC 42 ms 55,604 KB
testcase_13 AC 302 ms 108,992 KB
testcase_14 AC 294 ms 108,992 KB
testcase_15 AC 94 ms 76,936 KB
testcase_16 AC 138 ms 79,468 KB
testcase_17 AC 133 ms 78,568 KB
testcase_18 AC 107 ms 78,184 KB
testcase_19 AC 99 ms 76,412 KB
testcase_20 AC 269 ms 111,964 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
input = sys.stdin.readline

class SegmentTree:
    """Bottom-up segment tree over an associative ``combine_f``.

    Positions 0..n-1 live as leaves at indices ``size + i`` of the flat list
    ``_node``.  Here ``size`` is the least power of two that is >= n (at
    least 1).  Every inner index k holds
    ``combine_f(_node[2k], _node[2k + 1])``.  ``identity_e`` should be the
    identity of ``combine_f``; unused leaves hold it.
    """

    def __init__(self, n, identity_e, combine_f):
        size = 1
        while size < n:
            size *= 2
        self._n = n
        self._size = size
        self._identity_e = identity_e
        self._combine_f = combine_f
        self._node = [identity_e] * (size * 2)

    def build(self, array):
        """Load ``array`` (whose length must be n) into the leaves, then fill all parents."""
        assert len(array) == self._n
        node = self._node
        op = self._combine_f
        first_leaf = self._size
        for offset, value in enumerate(array):
            node[first_leaf + offset] = value
        # Children always have larger indices, so a descending sweep is enough.
        for k in reversed(range(1, first_leaf)):
            node[k] = op(node[2 * k], node[2 * k + 1])

    def update(self, index, value):
        """Store ``value`` at position ``index`` and recompute its ancestors."""
        node = self._node
        op = self._combine_f
        k = self._size + index
        node[k] = value
        k >>= 1
        while k > 0:
            node[k] = op(node[2 * k], node[2 * k + 1])
            k >>= 1

    def fold(self, L, R):
        """Return ``combine_f`` over positions [L, R), keeping left-to-right order."""
        node = self._node
        op = self._combine_f
        lo, hi = L + self._size, R + self._size
        acc_left = acc_right = self._identity_e
        while lo < hi:
            if lo % 2 == 1:
                acc_left = op(acc_left, node[lo])
                lo += 1
            if hi % 2 == 1:
                hi -= 1
                acc_right = op(node[hi], acc_right)
            lo //= 2
            hi //= 2
        return op(acc_left, acc_right)

    def get(self, p):
        """Return the value currently stored at position ``p``."""
        return self._node[self._size + p]

    def max_right(self, l, f):
        """Return the largest r with f(fold(l, r)) true.

        ``f(identity_e)`` must hold.  The binary descent assumes ``f`` is
        monotone: once false, it stays false as the range grows rightwards.
        """
        assert 0 <= l <= self._n
        assert f(self._identity_e)
        if l == self._n:
            return self._n
        node = self._node
        op = self._combine_f
        size = self._size
        k = l + size
        acc = self._identity_e
        while True:
            while k % 2 == 0:
                k >>= 1
            if not f(op(acc, node[k])):
                # Walk down, preferring left children, to the exact boundary.
                while k < size:
                    k <<= 1
                    if f(op(acc, node[k])):
                        acc = op(acc, node[k])
                        k += 1
                return k - size
            acc = op(acc, node[k])
            k += 1
            if k & -k == k:
                break
        return self._n

    def min_left(self, r, f):
        """Return the smallest l with f(fold(l, r)) true.

        ``f(identity_e)`` must hold.  The binary descent assumes ``f`` is
        monotone: once false, it stays false as the range grows leftwards.
        """
        assert 0 <= r <= self._n
        assert f(self._identity_e)
        if r == 0:
            return 0
        node = self._node
        op = self._combine_f
        size = self._size
        k = r + size
        acc = self._identity_e
        while True:
            k -= 1
            while k > 1 and k % 2 == 1:
                k >>= 1
            if not f(op(node[k], acc)):
                # Walk down, preferring right children, to the exact boundary.
                while k < size:
                    k = 2 * k + 1
                    if f(op(node[k], acc)):
                        acc = op(node[k], acc)
                        k -= 1
                return k + 1 - size
            acc = op(node[k], acc)
            if k & -k == k:
                break
        return 0

from operator import add
class EulerTour:
    """Euler tour of a weighted tree, with optional distance and LCA queries.

    After ``build(root)``, every vertex ``v`` owns two tour positions: its
    entry ``_in[v]`` and its exit ``_out[v]``.  The subtree of ``v``
    occupies exactly the contiguous tour range ``[_in[v], _out[v]]``.

    ``get_path``, ``get_lca`` and ``update`` rely on two SegmentTree
    instances over the 2N tour positions:

    * ``elist`` holds signed edge weights, so prefix sums give root
      distances.
    * ``dlist`` holds depth-encoded keys under ``min``.

    Both are created lazily from the built tour on first use.  Callers that
    only need the tour arrays therefore pay nothing for them.
    """

    def __init__(self, N):
        self.N = N
        self._in = [0] * N
        self._out = [0] * N
        self.depth = [-1] * N
        # weight[v] is the weight of the edge (par[v], v); it stays 0 for the root.
        self.weight = [0] * N
        self.par = [-1] * N
        self.tour = [-1] * (2 * N)
        self.G = [[] for _ in range(N)]
        # Query structures; built on demand by _prepare_queries().
        self.elist = None
        self.dlist = None

    def add_edge(self, u, v, w):
        """Add an undirected edge u-v of weight ``w``."""
        self.G[u].append((v, w))
        self.G[v].append((u, w))

    def build(self, root=0):
        """Run an iterative DFS from ``root`` and fill the tour arrays.

        At each position i, ``tour[i]`` is ``s`` when the walk enters vertex
        s and ``~s`` (negative) when it leaves s.  Children are pushed in
        adjacency order and therefore visited in reverse order.

        The graph is assumed to be a tree.  If some vertex is unreachable
        from ``root``, the stack runs out early and ``pop`` raises
        IndexError.
        """
        # Cached query structures describe the previous tour; drop them.
        self.elist = None
        self.dlist = None
        stack = [root]
        self.depth[root] = 0
        for i in range(2 * self.N):
            s = stack.pop()
            if s >= 0:
                # Entering s: queue its exit marker (~s < 0) beneath the
                # children, so it is popped after the whole subtree.
                stack.append(~s)
                self.tour[i] = s
                self._in[s] = i
                for u, w in self.G[s]:
                    if u == self.par[s]:
                        continue
                    self.par[u] = s
                    self.weight[u] = w
                    self.depth[u] = self.depth[s] + 1
                    stack.append(u)
            else:
                s = ~s
                self._out[s] = i
                self.tour[i] = ~s

    def _prepare_queries(self):
        """Create ``elist``/``dlist`` from the current tour if they are missing.

        Requires ``build()`` to have run first.
        """
        N = self.N
        size = 2 * N
        if self.elist is None:
            # +weight at a vertex's entry and -weight at its exit.  A prefix
            # up to _in[x] then counts the root->x edges once, and every
            # finished subtree cancels out.
            signed = [0] * size
            for v in range(N):
                signed[self._in[v]] = self.weight[v]
                signed[self._out[v]] = -self.weight[v]
            self.elist = SegmentTree(size, 0, add)
            self.elist.build(signed)
        if self.dlist is None:
            # v's entry is keyed by v itself; v's exit is keyed by its parent,
            # the vertex the walk returns to.  Real keys are at most
            # (N - 1) * N + N - 1, so N * N works as the min-identity.
            sentinel = N * N
            keys = [sentinel] * size
            for v in range(N):
                keys[self._in[v]] = self.depth[v] * N + v
                p = self.par[v]
                if p >= 0:
                    keys[self._out[v]] = self.depth[p] * N + p
            self.dlist = SegmentTree(size, sentinel, min)
            self.dlist.build(keys)

    def get_path(self, u, v):
        """Return the total edge weight on the tree path between ``u`` and ``v``."""
        self._prepare_queries()
        d = self.elist.fold(0, self._in[u] + 1) + self.elist.fold(0, self._in[v] + 1)
        lca = self.get_lca(u, v)
        d -= 2 * self.elist.fold(0, self._in[lca] + 1)
        return d

    def update(self, u, v, w):
        """Set the weight of tree edge (u, v) to ``w``.

        The edge is stored on its child endpoint: ``u`` if ``par[u] == v``,
        otherwise ``v``.  (u, v) is expected to be an actual tree edge; this
        is not validated.
        """
        self._prepare_queries()
        e = u if self.par[u] == v else v
        self.weight[e] = w
        self.elist.update(self._in[e], w)
        self.elist.update(self._out[e], -w)

    def get_lca(self, u, v):
        """Return the lowest common ancestor of ``u`` and ``v``."""
        self._prepare_queries()
        if self._in[u] > self._in[v]:
            u, v = v, u
        # The shallowest vertex touched between the two entries is the LCA;
        # keys are depth * N + vertex, so "% N" recovers the vertex id.
        return self.dlist.fold(self._in[u], self._in[v] + 1) % self.N
    
    
# Tree Xor Query: keep C[v] at v's Euler-tour entry position in a xor segment
# tree.  Exit positions stay 0 (the xor identity), so xor-folding the tour
# range [_in[x], _out[x]] yields the xor of C over the subtree of x.
from operator import xor

N, Q = map(int, input().split())
C = list(map(int, input().split()))
T = EulerTour(N)
for _ in range(N - 1):
    a, b = map(int, input().split())
    T.add_edge(a - 1, b - 1, 1)

T.build(0)
_in, _out = T._in, T._out

# Build the tree in O(N) from a full leaf array instead of N separate updates.
initial = [0] * (2 * N)
for v in range(N):
    initial[_in[v]] = C[v]
ST = SegmentTree(2 * N, 0, xor)
ST.build(initial)

answers = []
for _ in range(Q):
    t, x, y = map(int, input().split())
    x -= 1
    if t == 1:
        # Type 1: C[x] ^= y.
        pos = _in[x]
        ST.update(pos, ST.get(pos) ^ y)
    else:
        # Type 2: xor of C over the subtree of x.
        answers.append(ST.fold(_in[x], _out[x] + 1))

# One batched write instead of a print per query.
if answers:
    print("\n".join(map(str, answers)))
0