結果

問題 No.399 動的な領主
ユーザー titan23titan23
提出日時 2023-05-06 13:43:25
言語 PyPy3 (7.3.15)
結果
TLE  
実行時間 -
コード長 9,001 bytes
コンパイル時間 258 ms
コンパイル使用メモリ 81,920 KB
実行使用メモリ 203,852 KB
最終ジャッジ日時 2024-11-23 21:45:30
合計ジャッジ時間 38,152 ms
ジャッジサーバーID (参考情報) judge1 / judge2
このコードへのチャレンジ (要ログイン)

テストケース

テストケース表示
入力 結果 実行時間 実行使用メモリ
testcase_00 AC 81 ms 74,112 KB
testcase_01 AC 77 ms 161,880 KB
testcase_02 AC 150 ms 83,328 KB
testcase_03 AC 154 ms 174,908 KB
testcase_04 AC 516 ms 87,636 KB
testcase_05 TLE -
testcase_06 TLE -
testcase_07 TLE -
testcase_08 TLE -
testcase_09 TLE -
testcase_10 AC 779 ms 92,744 KB
testcase_11 AC 1,596 ms 183,788 KB
testcase_12 TLE -
testcase_13 TLE -
testcase_14 AC 680 ms 92,288 KB
testcase_15 TLE -
testcase_16 TLE -
testcase_17 TLE -
testcase_18 TLE -
権限があれば一括ダウンロードができます

ソースコード

diff #

from array import array
from typing import Generic, List, TypeVar, Callable, Iterable, Optional, Union
T = TypeVar('T')  # type of keys / aggregated values
F = TypeVar('F')  # type of lazy operators

class LinkCutTree(Generic[T, F]):
  """Link-cut tree with path aggregation and lazy path updates.

  Supported operations:
    - link / cut / merge / split
    - prod / apply / getitem / setitem
    - root / same
    - lca / path_kth_elm

  NOTE: despite the generic signature, this implementation is specialized
  to sum aggregation with additive lazy updates (``_update`` combines with
  ``+``; ``_propagate`` and ``apply`` add ``f`` to every key on a path).
  The ``op`` / ``mapping`` / ``composition`` / ``e`` / ``id`` constructor
  arguments are accepted for interface compatibility but are not used.
  Supporting another monoid requires adapting ``_propagate``, ``_update``,
  ``_update_double`` and ``apply``.  A non-commutative ``op`` would also
  need a reversed aggregate, because reversal here only swaps children.

  Representation: vertex ``i`` (``0 <= i < n``) is a splay-tree node; index
  ``n`` is a sentinel "null" node with key/data 0 and size 0.  Links are
  packed into one flat unsigned array:
    arr[i<<2|0] left child      arr[i<<2|1] right child
    arr[i<<2|2] parent          arr[i<<2|3] reverse flag
  A parent pointer is a *path-parent* when the parent does not list ``i``
  as one of its children.

  Tag invariant:
    - ``key[i]`` and ``data[i]`` already include ``lazy[i]``; ``lazy[i]``
      is still pending for ``i``'s children.
    - A set reverse flag on ``i`` means ``i``'s two children still have to
      be swapped, with the flag pushed further down.
    - Reversal changes neither a sum nor a size, so ``data[i]`` and
      ``size[i]`` are always valid for ``i``'s own subtree.
  """

  def __init__(self, n_or_a: Union[int, Iterable[T]], \
              op: Callable[[T, T], T]=lambda x, y: None, \
              mapping: Callable[[F, T], T]=lambda x, y: None, \
              composition: Callable[[F, F], F]=lambda x, y: None, \
              e: T=None, id: F=None):
    """Create a forest of isolated vertices.

    n_or_a: the number of vertices (every key starts at 0), or an iterable
      giving the initial key of each vertex.
    op, mapping, composition, e, id: unused; see the class docstring.
    """
    # list(n_or_a) alone would raise TypeError for the documented int form
    self.key: List[T] = [0] * n_or_a if isinstance(n_or_a, int) else list(n_or_a)
    self.n = len(self.key)
    self.key.append(0)                      # key of the sentinel node n
    self.data: List[T] = self.key[:]        # subtree sums
    self.lazy: List[F] = [0] * (self.n+1)   # pending additive tag per node
    # 4 slots per node: left, right, parent (all -> sentinel), reverse flag
    self.arr: array = array('I', [self.n, self.n, self.n, 0] * (self.n+1))
    self.size: array = array('I', [1] * (self.n+1))  # subtree sizes
    self.size[-1] = 0                       # the sentinel contributes nothing
    self.group_cnt = self.n                 # number of connected components

  def _is_root(self, node: int) -> bool:
    """Return True if node is the root of its splay tree.

    That holds when it has no parent, or when its parent pointer is only a
    path-parent (the parent lists neither child slot as node).
    """
    arr = self.arr
    p = arr[node<<2|2]
    return p == self.n or not (arr[p<<2] == node or arr[p<<2|1] == node)

  def _propagate(self, node: int) -> None:
    """Push node's pending reverse flag and additive tag down to its children."""
    n = self.n
    if node == n: return
    arr = self.arr
    if arr[node<<2|3]:
      arr[node<<2|3] = 0
      ln, rn = arr[node<<2], arr[node<<2|1]
      arr[node<<2] = rn
      arr[node<<2|1] = ln
      # may toggle the sentinel's flag; harmless, the sentinel is never propagated
      arr[ln<<2|3] ^= 1
      arr[rn<<2|3] ^= 1
    lazy = self.lazy
    nlazy = lazy[node]
    if not nlazy:
      return  # nothing to push; skip the child updates entirely
    data, key, size = self.data, self.key, self.size
    lnode, rnode = arr[node<<2], arr[node<<2|1]
    if lnode != n:
      data[lnode] += nlazy * size[lnode]
      key[lnode] += nlazy
      lazy[lnode] += nlazy
    if rnode != n:
      data[rnode] += nlazy * size[rnode]
      key[rnode] += nlazy
      lazy[rnode] += nlazy
    lazy[node] = 0

  def _update(self, node: int) -> None:
    """Recompute data and size of node from its children.

    node itself must already be propagated.  Its children need not be,
    because their data/size already account for their own pending tags.
    """
    if node == self.n: return
    arr, data, size = self.arr, self.data, self.size
    ln, rn = arr[node<<2], arr[node<<2|1]
    data[node] = data[ln] + self.key[node] + data[rn]
    size[node] = 1 + size[ln] + size[rn]

  def _update_triple(self, x: int, y: int, z: int) -> None:
    """Recompute x, then y, then z (bottom-up order after a double rotation)."""
    _update = self._update
    _update(x)
    _update(y)
    _update(z)

  def _update_double(self, x: int, y: int) -> None:
    """Fix aggregates after a single rotation that moved y above x.

    y now spans exactly the vertices x spanned before, so it takes over x's
    old aggregate; x is recomputed from its new children.
    """
    data, key, arr, size = self.data, self.key, self.arr, self.size
    lx, rx = arr[x<<2], arr[x<<2|1]
    data[y] = data[x]
    data[x] = data[lx] + key[x] + data[rx]
    size[y] = size[x]
    size[x] = 1 + size[lx] + size[rx]

  def _splay(self, node: int) -> None:
    """Rotate node up to the root of its splay tree.

    On return node has been propagated, so the caller can read its
    left/right children directly.
    """
    if node == self.n: return
    _propagate, _is_root, _update_triple = self._propagate, self._is_root, self._update_triple
    _propagate(node)
    if _is_root(node): return
    n, arr = self.n, self.arr
    pnode = arr[node<<2|2]
    while not _is_root(pnode):
      # double rotation (zig-zig / zig-zag) with grandparent gnode
      gnode = arr[pnode<<2|2]
      # push tags top-down so that the child slots reflect the true orientation
      _propagate(gnode)
      _propagate(pnode)
      _propagate(node)
      f = arr[pnode<<2] == node  # True: node is pnode's left child
      # g is True for zig-zag (node and pnode hang on opposite sides)
      g = (arr[gnode<<2|f] == pnode) ^ (arr[pnode<<2|f] == node)
      nnode = (node if g else pnode) << 2 | f ^ g
      arr[pnode<<2|f^1] = arr[node<<2|f]
      arr[gnode<<2|f^g^1] = arr[nnode]
      arr[node<<2|f] = pnode
      arr[nnode] = gnode
      arr[node<<2|2] = arr[gnode<<2|2]
      arr[gnode<<2|2] = nnode>>2
      arr[arr[pnode<<2|f^1]<<2|2] = pnode
      arr[arr[gnode<<2|f^g^1]<<2|2] = gnode
      arr[pnode<<2|2] = node
      _update_triple(gnode, pnode, node)
      # re-point gnode's old parent at node, if it was a real splay parent
      pnode = arr[node<<2|2]
      if arr[pnode<<2] == gnode:
        arr[pnode<<2] = node
      elif arr[pnode<<2|1] == gnode:
        arr[pnode<<2|1] = node
      else:
        # gnode was the splay root (path-parent or sentinel above it),
        # so node is the root now
        return
    # single rotation (zig) with the splay root pnode
    _propagate(pnode)
    _propagate(node)
    f = arr[pnode<<2] == node
    arr[pnode<<2|f^1] = arr[node<<2|f]
    arr[node<<2|f] = pnode
    arr[arr[pnode<<2|f^1]<<2|2] = pnode
    arr[node<<2|2] = arr[pnode<<2|2]
    arr[pnode<<2|2] = node
    self._update_double(pnode, node)

  def expose(self, v: int) -> int:
    """Make the path from the root of v's tree to v a single preferred path.

    Afterwards v is the root of the top splay tree (no parent), it is
    propagated, its right child is the sentinel, and data[v] / size[v]
    describe exactly the root -> v path.

    Returns the last path-parent joined during the walk, or v itself if none
    was.  lca() relies on this: expose(u) followed by expose(v) returns the
    lowest common ancestor.
    """
    arr, n, _splay, _update = self.arr, self.n, self._splay, self._update
    pre = v
    while arr[v<<2|2] != n:
      _splay(v)
      arr[v<<2|1] = n          # the deeper part becomes a virtual child
      _update(v)
      if arr[v<<2|2] == n:
        break
      pre = arr[v<<2|2]        # path-parent: hang v's path below it
      _splay(pre)
      arr[pre<<2|1] = v
      _update(pre)
    # v may have been the top root already, without going through _splay.
    # For example, the part detached by cut() keeps its pending tags.  Push
    # them before overwriting the right child so the logically correct child
    # is cut off.
    self._propagate(v)
    arr[v<<2|1] = n
    _update(v)
    return pre

  def lca(self, root: int, u: int, v: int) -> int:
    """Return the LCA of u and v with the tree rooted at root.

    root, u and v must be in the same tree; the tree is re-rooted at root.
    """
    self.evert(root)
    self.expose(u)
    return self.expose(v)

  def link(self, c: int, p: int) -> None:
    """Add the edge c -> p, making p the parent of c.

    c must be the root of its represented tree (e.g. call evert(c) first).
    This is unrelated to being a splay-tree root (_is_root).
    """
    assert not self.same(c, p)
    self.expose(c)
    self.expose(p)
    self.arr[c<<2|2] = p
    self.arr[p<<2|1] = c
    self._update(p)
    self.group_cnt -= 1

  def cut(self, c: int) -> None:
    """Remove the edge between c and its parent; c must not be a tree root."""
    arr = self.arr
    self.expose(c)
    # after expose(c) the left subtree is the path root .. parent(c)
    assert arr[c<<2] != self.n
    arr[arr[c<<2]<<2|2] = self.n
    arr[c<<2] = self.n
    self._update(c)
    self.group_cnt += 1

  def group_count(self) -> int:
    """Return the number of connected components."""
    return self.group_cnt

  def root(self, v: int) -> int:
    """Return the root of the tree containing v."""
    self.expose(v)
    arr, n = self.arr, self.n
    # the root is the leftmost (shallowest) vertex of the exposed path
    while arr[v<<2] != n:
      v = arr[v<<2]
      self._propagate(v)
    self._splay(v)
    return v

  def same(self, u: int, v: int) -> bool:
    """Return whether u and v are in the same connected component."""
    return self.root(u) == self.root(v)

  def evert(self, v: int) -> None:
    """Make v the root of its tree.

    expose(v) makes v the last vertex of the root -> v path; flipping the
    reverse flag makes it the first.  v is left propagated.
    """
    self.expose(v)
    self.arr[v<<2|3] ^= 1
    self._propagate(v)

  def prod(self, u: int, v: int) -> T:
    """Return the sum of keys on the path u -> v.

    u and v must be connected; the tree is re-rooted at u.
    """
    self.evert(u)
    self.expose(v)
    return self.data[v]

  def apply(self, u: int, v: int, f: F) -> None:
    """Add f to every key on the path u -> v.

    u and v must be connected; the tree is re-rooted at u.
    """
    self.evert(u)
    self.expose(v)
    self.key[v] += f
    self.data[v] += f * self.size[v]
    self.lazy[v] += f
    self._propagate(v)

  def merge(self, u: int, v: int) -> bool:
    """Add the edge u - v.

    Returns True on success, or False (no edge added) when u and v are
    already in the same tree, because the edge would create a cycle.
    """
    if u == v: return False
    self.evert(u)
    self.expose(v)
    arr, n = self.arr, self.n
    # evert(u) left u as the top splay root (parent == sentinel).  If v is in
    # the same tree, expose(v) pulls u (now the tree root, so on the exposed
    # path) into v's splay tree below v, giving it a real parent.
    if arr[u<<2|2] != n: return False
    arr[u<<2|2] = v
    arr[v<<2|1] = u
    self._update(v)
    self.group_cnt -= 1
    return True

  def split(self, u: int, v: int) -> bool:
    """Remove the edge u - v.

    Returns False, with the edge set unchanged, when no edge u - v exists.
    """
    if not self.same(v, u): return False
    self.evert(u)
    self.expose(v)
    if self.size[v] != 2:
      # the u .. v path has 1 vertex (u == v) or more than 2: not adjacent
      return False
    self.cut(v)
    return True

  def path_kth_elm(self, s: int, t: int, k: int) -> Optional[int]:
    """Return the k-th vertex (0-indexed from s) on the path s -> t.

    Returns None when the path has at most k vertices.  The tree is
    re-rooted at s.
    """
    self.evert(s)
    self.expose(t)
    size, arr = self.size, self.arr
    if size[t] <= k:
      return None
    while True:
      self._propagate(t)
      lsize = size[arr[t<<2]]  # vertices of this subtree that precede t
      if lsize == k:
        self._splay(t)
        return t
      t = arr[t<<2|(lsize<k)]
      if lsize < k:
        k -= lsize + 1

  def __setitem__(self, k: int, v: T):
    """Set the key of vertex k to v."""
    self._splay(k)
    self.key[k] = v
    self._update(k)

  def __getitem__(self, k: int) -> T:
    """Return the current key of vertex k."""
    self._splay(k)
    return self.key[k]

  def __str__(self):
    # TODO: not implemented yet
    return 'LinkCutTree()'

  def __repr__(self):
    # TODO: not implemented yet
    return 'LinkCutTree()'


import sys


def main() -> None:
  """Solve problem No.399 with LinkCutTree.

  Every vertex starts with key 1.  For each query (a, b) the answer grows by
  the current sum of keys on the path a - b; afterwards every key on that
  path is incremented by 1.
  """
  # Read everything at once: per-line readline() is slow under PyPy.
  it = map(int, sys.stdin.buffer.read().split())
  n = next(it)
  lct = LinkCutTree([1]*n)
  for _ in range(n-1):
    u = next(it) - 1
    v = next(it) - 1
    lct.merge(u, v)
  ans = 0
  q = next(it)
  for _ in range(q):
    a = next(it) - 1
    b = next(it) - 1
    ans += lct.prod(a, b)
    lct.apply(a, b, 1)
  print(ans)


if __name__ == '__main__':
  main()