結果

問題 No.2676 A Tourist
ユーザー nu50218nu50218
提出日時 2024-03-13 20:20:18
言語 PyPy3 (7.3.15)
結果 AC
実行時間 1,331 ms / 5,000 ms
コード長 11,880 bytes
コンパイル時間 367 ms
コンパイル使用メモリ 82,288 KB
実行使用メモリ 184,392 KB
最終ジャッジ日時 2024-09-29 23:18:24
合計ジャッジ時間 23,424 ms
ジャッジサーバーID (参考情報) judge4 / judge5
このコードへのチャレンジ (要ログイン)

テストケース

テストケース表示
入力 結果 実行時間 実行使用メモリ
testcase_00 AC 58 ms 70,656 KB
testcase_01 AC 590 ms 104,216 KB
testcase_02 AC 1,331 ms 146,860 KB
testcase_03 AC 680 ms 147,300 KB
testcase_04 AC 966 ms 146,736 KB
testcase_05 AC 1,098 ms 146,608 KB
testcase_06 AC 407 ms 107,532 KB
testcase_07 AC 666 ms 170,672 KB
testcase_08 AC 402 ms 167,800 KB
testcase_09 AC 639 ms 170,676 KB
testcase_10 AC 512 ms 171,616 KB
testcase_11 AC 1,000 ms 143,728 KB
testcase_12 AC 329 ms 107,748 KB
testcase_13 AC 536 ms 162,008 KB
testcase_14 AC 487 ms 162,512 KB
testcase_15 AC 558 ms 162,004 KB
testcase_16 AC 592 ms 104,968 KB
testcase_17 AC 1,276 ms 146,916 KB
testcase_18 AC 694 ms 147,544 KB
testcase_19 AC 1,189 ms 146,668 KB
testcase_20 AC 1,061 ms 147,668 KB
testcase_21 AC 165 ms 79,144 KB
testcase_22 AC 292 ms 82,992 KB
testcase_23 AC 165 ms 80,572 KB
testcase_24 AC 258 ms 82,488 KB
testcase_25 AC 183 ms 80,848 KB
testcase_26 AC 60 ms 70,528 KB
testcase_27 AC 313 ms 114,064 KB
testcase_28 AC 759 ms 184,140 KB
testcase_29 AC 473 ms 184,392 KB
testcase_30 AC 675 ms 184,268 KB
testcase_31 AC 565 ms 184,388 KB
testcase_32 AC 368 ms 150,524 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

#!/usr/bin/pypy3

# 以下の提出をベースに改変
# https://judge.yosupo.jp/submission/160607
# アルゴリズム的には解説を実装したもの

import os
from __pypy__.builders import StringBuilder

class FastO():
  '''Buffered stdout writer backed by a PyPy StringBuilder.

  Text accumulates in ``sb`` and is only written to file descriptor 1 by
  ``flush``; nothing reaches stdout unless ``flush`` is called.
  '''

  sb = StringBuilder()

  @classmethod
  def write(cls, *args, sep: str=' ', end: str='\n', flush: bool=False) -> None:
    '''Buffer ``str`` of each arg joined by ``sep``, followed by ``end``.

    Mirrors ``print``'s formatting. If ``flush`` is true, the buffer is
    written out immediately.
    '''
    append = cls.sb.append
    append(sep.join(map(str, args)))
    append(end)
    if flush:
      cls.flush()

  @classmethod
  def flush(cls) -> None:
    '''Write all buffered text to file descriptor 1 and reset the buffer.'''
    data = memoryview(cls.sb.build().encode())
    # os.write returns how many bytes were actually written, which can be
    # fewer than requested; keep going until the whole buffer is out.
    while data:
      written = os.write(1, data)
      data = data[written:]
    cls.sb = StringBuilder()

# Module-level shortcuts to the buffered writer's class methods.
write, flush = FastO.write, FastO.flush

import sys

def input():
  '''Read one raw line from stdin as bytes, trailing whitespace stripped.'''
  line = sys.stdin.buffer.readline()
  return line.rstrip()

#  -----------------------  #

from typing import Any, List, Generator, Tuple
from types import GeneratorType
from __pypy__ import newlist_hint

class HLD():
  '''Heavy-light decomposition of a tree given as adjacency lists.

  After construction (lists are indexed by vertex unless noted):
    size[v]    -- number of vertices in the subtree of v.
    par[v]     -- parent of v, -1 for the root.
    dep[v]     -- depth of v, 0 for the root.
    nodein[v]  -- position of v in the HLD order.
    nodeout[v] -- one past the last position of v's subtree, so the subtree
                  occupies positions [nodein[v], nodeout[v]).
    head[v]    -- topmost vertex of the heavy chain containing v.
    hld[i]     -- vertex at position i (inverse of nodein).

  NOTE: the adjacency lists of G are reordered in place so that G[v][0] is
  the child of v with the largest subtree whenever v has a child.
  '''

  def __init__(self, G: List[List[int]], root: int):
    n = len(G)
    self.n: int = n
    self.G: List[List[int]] = G
    self.size: List[int] = [1] * n
    self.par: List[int] = [-1] * n
    self.dep: List[int] = [-1] * n
    self.nodein: List[int] = [-1] * n
    self.nodeout: List[int] = [-1] * n
    self.head: List[int] = [0] * n
    self.hld: List[int] = [0] * n
    self.dfs(root)

  def dfs(self, root: int) -> None:
    '''Fill par/dep/size, choose heavy children, then assign HLD positions.

    Both passes are iterative. A vertex v is pushed as v for its entry step
    and as ~v (always negative) for its exit step.
    '''
    dep, par, size, G = self.dep, self.par, self.size, self.G
    dep[root] = 0
    # Pass 1: depths on entry; parent, subtree size and heavy child on exit.
    # The root needs its own exit marker, otherwise size[root] is never
    # computed and the root's heavy child is never chosen.
    stack = [~root, root]
    while stack:
      v = stack.pop()
      if v >= 0:
        for x in G[v]:
          if dep[x] != -1:
            continue
          dep[x] = dep[v] + 1
          stack.append(~x)
          stack.append(x)
      else:
        v = ~v
        adj = G[v]
        s = 1
        heavy = -1
        heavy_size = 0
        for i, x in enumerate(adj):
          if dep[x] < dep[v]:
            par[v] = x
            continue
          s += size[x]
          # Compare children only: the parent's size is not final yet.
          if size[x] > heavy_size:
            heavy_size = size[x]
            heavy = i
        if heavy > 0:
          adj[0], adj[heavy] = adj[heavy], adj[0]
        size[v] = s
    head, nodein, nodeout, hld = self.head, self.nodein, self.nodeout, self.hld
    curtime = 0
    # Pass 2: preorder numbering. G[v][0] is pushed last, so the heavy child
    # is numbered right after v and each heavy chain gets consecutive positions.
    stack = [~root, root]
    while stack:
      v = stack.pop()
      if v >= 0:
        if par[v] == -1:
          head[v] = v
        nodein[v] = curtime
        hld[curtime] = v
        curtime += 1
        for x in reversed(G[v]):
          if x == par[v]:
            continue
          head[x] = head[v] if x == G[v][0] else x
          stack.append(~x)
          stack.append(x)
      else:
        nodeout[~v] = curtime

  def path_kth_elm(self, s: int, t: int, k: int) -> int:
    '''Return the k-th vertex (0-indexed) on the path s -> t, or -1 if the
    path has fewer than k + 1 vertices.'''
    head, dep, par = self.head, self.dep, self.par
    lca = self.lca(s, t)
    d = dep[s] + dep[t] - 2*dep[lca]
    if d < k:
      return -1
    # If the target lies on the t side, count from t instead.
    if dep[s] - dep[lca] < k:
      s = t
      k = d - k
    hs = head[s]
    # Climb whole chains until the k-th ancestor lies on the current chain.
    while dep[s] - dep[hs] < k:
      k -= dep[s] - dep[hs] + 1
      s = par[hs]
      hs = head[s]
    return self.hld[self.nodein[s] - k]

  def lca(self, u: int, v: int) -> int:
    '''Return the lowest common ancestor of u and v.'''
    nodein, head, par = self.nodein, self.head, self.par
    while True:
      if nodein[u] > nodein[v]:
        u, v = v, u
      if head[u] == head[v]:
        return u
      v = par[head[v]]

  def dist(self, u: int, v: int) -> int:
    '''Return the number of edges on the path between u and v.'''
    lca = self.lca(u, v)
    dep = self.dep
    return dep[u] + dep[v] - 2 * dep[lca]

  def build_list(self, a: List[Any]) -> List[Any]:
    '''Reorder per-vertex values into HLD position order: result[i] = a[hld[i]].'''
    return [a[e] for e in self.hld]

  def for_each_vertex(self, u: int, v: int) -> Generator[Tuple[int, int], None, None]:
    '''Yield half-open position ranges [l, r) that together cover exactly
    the vertices on the u-v path. Ranges are not yielded in path order.'''
    head, nodein, dep, par = self.head, self.nodein, self.dep, self.par
    while head[u] != head[v]:
      if dep[head[u]] < dep[head[v]]:
        u, v = v, u
      yield nodein[head[u]], nodein[u]+1
      u = par[head[u]]
    if dep[u] < dep[v]:
      u, v = v, u
    yield nodein[v], nodein[u]+1

  def for_each_vertex_subtree(self, v: int) -> Tuple[int, int]:
    '''Return the half-open position range [l, r) of v's subtree.'''
    return self.nodein[v], self.nodeout[v]


from abc import ABC, abstractmethod
from typing import TypeVar, Generic, Union, Iterable, Callable, List
T = TypeVar('T')

class SegmentTreeInterface(ABC, Generic[T]):
  '''Abstract API for segment trees combining elements with ``op``.

  Every method is abstract, so a subclass must override all of them before
  it can be instantiated. The bodies here only raise NotImplementedError.
  '''

  # ---- construction ----

  @abstractmethod
  def __init__(self, n_or_a: Union[int, Iterable[T]],
               op: Callable[[T, T], T],
               e: T):
    '''Build from a length or from initial values, with combiner op and unit e.'''
    raise NotImplementedError

  # ---- point access ----

  @abstractmethod
  def set(self, k: int, v: T) -> None:
    '''Replace element k with v.'''
    raise NotImplementedError

  @abstractmethod
  def get(self, k: int) -> T:
    '''Return element k.'''
    raise NotImplementedError

  @abstractmethod
  def __getitem__(self, k: int) -> T:
    '''Subscript read of element k.'''
    raise NotImplementedError

  @abstractmethod
  def __setitem__(self, k: int, v: T) -> None:
    '''Subscript write of element k.'''
    raise NotImplementedError

  # ---- range queries ----

  @abstractmethod
  def prod(self, l: int, r: int) -> T:
    '''Return op folded over the half-open range [l, r).'''
    raise NotImplementedError

  @abstractmethod
  def all_prod(self) -> T:
    '''Return op folded over every element.'''
    raise NotImplementedError

  @abstractmethod
  def max_right(self, l: int, f: Callable[[T], bool]) -> int:
    '''Return the largest R such that f holds for the fold of [l, R).'''
    raise NotImplementedError

  @abstractmethod
  def min_left(self, r: int, f: Callable[[T], bool]) -> int:
    '''Return the smallest L such that f holds for the fold of [L, r).'''
    raise NotImplementedError

  # ---- conversion / display ----

  @abstractmethod
  def tolist(self) -> List[T]:
    '''Return all elements as a list.'''
    raise NotImplementedError

  @abstractmethod
  def __str__(self):
    raise NotImplementedError

  @abstractmethod
  def __repr__(self):
    raise NotImplementedError
  
from typing import Generic, Iterable, TypeVar, Callable, Union, List
T = TypeVar('T')

class SegmentTree(SegmentTreeInterface, Generic[T]):
  '''Iterative point-update / range-fold segment tree over (op, e).

  Leaves live at _data[_size:_size+_n]; unused leaves and the identity of
  every fold are ``e``, so ``op`` should be associative with identity ``e``.
  Index checks use ``assert`` and are therefore skipped under ``python -O``.
  '''

  def __init__(self, n_or_a: Union[int, Iterable[T]], op: Callable[[T, T], T], e: T):
    '''Build a new SegmentTree. / O(N)'''
    self._op = op
    self._e = e
    if isinstance(n_or_a, int):
      self._n = n_or_a
      self._log  = (self._n - 1).bit_length()
      self._size = 1 << self._log
      self._data = [e] * (self._size << 1)
    else:
      n_or_a = list(n_or_a)
      self._n = len(n_or_a)
      self._log  = (self._n - 1).bit_length()
      self._size = 1 << self._log
      _data = [e] * (self._size << 1)
      _data[self._size:self._size+self._n] = n_or_a
      for i in range(self._size-1, 0, -1):
        _data[i] = op(_data[i<<1], _data[i<<1|1])
      self._data = _data

  def set(self, k: int, v: T) -> None:
    '''Update a[k] <- x. / O(logN)'''
    assert -self._n <= k < self._n, \
        f'IndexError: SegmentTree.set({k}, {v}), n={self._n}'
    if k < 0:
      k += self._n
    k += self._size
    self._data[k] = v
    # Recompute every ancestor of the updated leaf.
    for _ in range(self._log):
      k >>= 1
      self._data[k] = self._op(self._data[k<<1], self._data[k<<1|1])

  def get(self, k: int) -> T:
    '''Return a[k]. / O(1)'''
    assert -self._n <= k < self._n, \
        f'IndexError: SegmentTree.get({k}), n={self._n}'
    if k < 0:
      k += self._n
    return self._data[k+self._size]

  def prod(self, l: int, r: int) -> T:
    '''Return op([l, r)). / O(logN)'''
    assert 0 <= l <= r <= self._n, \
        f'IndexError: SegmentTree.prod({l}, {r})'
    l += self._size
    r += self._size
    lres = self._e
    rres = self._e
    while l < r:
      if l & 1:
        lres = self._op(lres, self._data[l])
        l += 1
      if r & 1:
        # r is exclusive and odd here, so r ^ 1 == r - 1 is the last node.
        rres = self._op(self._data[r^1], rres)
      l >>= 1
      r >>= 1
    return self._op(lres, rres)

  def all_prod(self) -> T:
    '''Return op([0, n)). / O(1)'''
    return self._data[1]

  def max_right(self, l: int, f: Callable[[T], bool]) -> int:
    '''Find the largest index R s.t. f([l, R)) == True. / O(logN)'''
    assert 0 <= l <= self._n, \
        f'IndexError: SegmentTree.max_right({l}, f) index out of range'
    assert f(self._e), \
        f'SegmentTree.max_right({l}, f), f({self._e}) must be true.'
    if l == self._n:
      return self._n 
    l += self._size
    s = self._e
    while True:
      while l & 1 == 0:
        l >>= 1
      if not f(self._op(s, self._data[l])):
        # Descend towards the first leaf where f becomes false.
        while l < self._size:
          l <<= 1
          if f(self._op(s, self._data[l])):
            s = self._op(s, self._data[l])
            l |= 1
        return l - self._size
      s = self._op(s, self._data[l])
      l += 1
      if l & -l == l:
        break
    return self._n

  def min_left(self, r: int, f: Callable[[T], bool]) -> int:
    '''Find the smallest index L s.t. f([L, r)) == True. / O(logN)'''
    assert 0 <= r <= self._n, \
        f'IndexError: SegmentTree.min_left({r}, f) index out of range'
    assert f(self._e), \
        f'SegmentTree.min_left({r}, f), f({self._e}) must be true.'
    if r == 0:
      return 0 
    r += self._size
    s = self._e
    while True:
      r -= 1
      while r > 1 and r & 1:
        r >>= 1
      if not f(self._op(self._data[r], s)):
        # Descend towards the last leaf where f becomes false.
        while r < self._size:
          r = r << 1 | 1
          if f(self._op(self._data[r], s)):
            s = self._op(self._data[r], s)
            r ^= 1
        return r + 1 - self._size
      s = self._op(self._data[r], s)
      if r & -r == r:
        break 
    return 0

  def tolist(self) -> List[T]:
    '''Return List[self]. / O(N)'''
    return [self.get(i) for i in range(self._n)]

  def show(self) -> None:
    '''Debug. / O(N)'''
    print('<SegmentTree> [\n' + '\n'.join(['  ' + ' '.join(map(str, [self._data[(1<<i)+j] for j in range(1<<i)])) for i in range(self._log+1)]) + '\n]')

  def __getitem__(self, k: int) -> T:
    assert -self._n <= k < self._n, \
        f'IndexError: SegmentTree.__getitem__({k}), n={self._n}'
    return self.get(k)

  def __setitem__(self, k: int, v: T):
    # The message previously read "__setitem__{k}, {v})" (missing "(").
    assert -self._n <= k < self._n, \
        f'IndexError: SegmentTree.__setitem__({k}, {v}), n={self._n}'
    self.set(k, v)

  def __str__(self):
    return str(self.tolist())

  def __repr__(self):
    return f'SegmentTree({self})'

# def op(s, t):
#   return

# e = None

from typing import Union, Iterable, Callable, TypeVar, Generic
T = TypeVar('T')

class HLDSegmentTree(Generic[T]):
  '''Segment tree indexed by tree vertex, stored in HLD position order.

  Vertex v lives at position hld.nodein[v]. A path query therefore splits
  into one contiguous range per heavy chain it crosses, and a subtree is a
  single range.

  NOTE(review): path_prod folds the per-chain ranges in the order it finds
  them, not in u -> v path order. Results are only meaningful when ``op`` is
  commutative (e.g. integer addition).
  '''

  def __init__(self, hld: HLD, n_or_a: Union[int, Iterable[T]], op: Callable[[T, T], T], e: T):
    self.hld: HLD = hld
    if not isinstance(n_or_a, int):
      # Per-vertex initial values -> HLD position order.
      n_or_a = hld.build_list(list(n_or_a))
    self.seg: SegmentTree[T] = SegmentTree(n_or_a, op, e)
    self.op: Callable[[T, T], T] = op
    self.e: T = e

  def path_prod(self, u: int, v: int) -> T:
    '''Fold op over every vertex on the path between u and v (inclusive).'''
    h = self.hld
    head, nodein, dep, par = h.head, h.nodein, h.dep, h.par
    op, seg = self.op, self.seg
    acc = self.e
    while head[u] != head[v]:
      # Always lift the endpoint whose chain head is deeper.
      if dep[head[v]] > dep[head[u]]:
        u, v = v, u
      top = head[u]
      acc = op(acc, seg.prod(nodein[top], nodein[u] + 1))
      u = par[top]
    # Same chain now: the shallower endpoint has the smaller position.
    lo, hi = (v, u) if dep[v] <= dep[u] else (u, v)
    return op(acc, seg.prod(nodein[lo], nodein[hi] + 1))

  def get(self, k: int) -> T:
    '''Return the value stored for vertex k.'''
    pos = self.hld.nodein[k]
    return self.seg[pos]

  def set(self, k: int, v: T) -> None:
    '''Store v as the value of vertex k.'''
    pos = self.hld.nodein[k]
    self.seg[pos] = v

  __getitem__ = get
  __setitem__ = set

  def subtree_prod(self, v: int) -> T:
    '''Fold op over every vertex in the subtree of v.'''
    lo, hi = self.hld.nodein[v], self.hld.nodeout[v]
    return self.seg.prod(lo, hi)


#  -----------------------  #

# Segment-tree monoid for the child sums: integer addition, identity 0.
e = 0
def op(s, t):
  return s + t

#  -----------------------  #

n, q = map(int, input().split())

# Vertex values, 1-indexed. Index 0 is a dummy vertex with value 0 that is
# attached as the parent of vertex 1, so par[lca] below is always valid.
a = list(map(int, input().split()))
a.insert(0, 0)

# Read the edges and build the adjacency list, including the dummy edge 0-1.
G = [[] for _ in range(n + 1)]
G[0].append(1)
G[1].append(0)
for _ in range(n-1):
  u, v = map(int, input().split())
  G[u].append(v)
  G[v].append(u)

# Decompose the tree rooted at the dummy vertex 0. Its parent array replaces
# a separate traversal: par[v] is v's parent, -1 for the root.
hld = HLD(G, 0)
par = hld.par

# csum[v] = sum of a over the children of v (initial values only).
csum = [0] * (n + 1)
for i in range(1, n + 1):
  if par[i] != -1:
    csum[par[i]] += a[i]

# Only the child sums need path queries; a[] is read point-wise from the list.
seg_csum = HLDSegmentTree(hld, csum, op, e)

# Answer the queries.
for _ in range(q):
  t, arg1, arg2 = map(int, input().split())

  if t == 0:
    # Type 0: a[u] += x. Besides a[u], only the child sum of u's parent changes.
    u, x = arg1, arg2
    a[u] += x
    seg_csum[par[u]] += x

  elif t == 1:
    # Type 1: sum of a over the u-v path plus every vertex adjacent to it.
    # The child sums along the path cover all of these except the LCA itself
    # and the LCA's parent, which are added explicitly.
    u, v = arg1, arg2
    lca = hld.lca(u, v)
    ans = seg_csum.path_prod(u, v) + a[lca] + a[par[lca]]
    write(ans)

flush()