Results

Problem | No.901 K-ary εxtrεεmε |
User | |
Submission time | 2025-03-26 21:14:02 |
Language | PyPy3 (7.3.15) |
Result | AC |
Run time | 704 ms / 3,000 ms |
Code length | 45,090 bytes |
Compile time | 479 ms |
Compile memory | 86,528 KB |
Runtime memory | 147,912 KB |
Last judged | 2025-03-26 21:14:22 |
Total judge time | 18,534 ms |
Judge server ID (informational) | judge3 / judge1 |
File pattern | Result |
---|---|
sample | AC * 1 |
other | AC * 29 |
Source code
# verification-helper: PROBLEM https://yukicoder.me/problems/3407 def main(): N = read(int) T = read(AuxTreeWeighted[N,0]) Q = read(int) for _ in range(Q): k, *X = read() V, post = T.tree(X) ans = sum(T.Wa[i] for i in post) write(ans) ''' ╺━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸ https://kobejean.github.io/cp-library ''' import typing from collections import deque from numbers import Number from types import GenericAlias from typing import Callable, Collection, Iterator, Union import os import sys from io import BytesIO, IOBase class FastIO(IOBase): BUFSIZE = 8192 newlines = 0 def __init__(self, file): self._fd = file.fileno() self.buffer = BytesIO() self.writable = "x" in file.mode or "r" not in file.mode self.write = self.buffer.write if self.writable else None def read(self): BUFSIZE = self.BUFSIZE while True: b = os.read(self._fd, max(os.fstat(self._fd).st_size, BUFSIZE)) if not b: break ptr = self.buffer.tell() self.buffer.seek(0, 2), self.buffer.write(b), self.buffer.seek(ptr) self.newlines = 0 return self.buffer.read() def readline(self): BUFSIZE = self.BUFSIZE while self.newlines == 0: b = os.read(self._fd, max(os.fstat(self._fd).st_size, BUFSIZE)) self.newlines = b.count(b"\n") + (not b) ptr = self.buffer.tell() self.buffer.seek(0, 2), self.buffer.write(b), self.buffer.seek(ptr) self.newlines -= 1 return self.buffer.readline() def flush(self): if self.writable: os.write(self._fd, self.buffer.getvalue()) self.buffer.truncate(0), self.buffer.seek(0) class IOWrapper(IOBase): stdin: 'IOWrapper' = None stdout: 'IOWrapper' = None def __init__(self, file): self.buffer = FastIO(file) self.flush = self.buffer.flush self.writable = self.buffer.writable def write(self, s): return self.buffer.write(s.encode("ascii")) def read(self): return self.buffer.read().decode("ascii") def readline(self): return self.buffer.readline().decode("ascii") sys.stdin = IOWrapper.stdin = IOWrapper(sys.stdin) sys.stdout = IOWrapper.stdout = IOWrapper(sys.stdout) from typing import TypeVar _T = TypeVar('T') class TokenStream(Iterator): stream = IOWrapper.stdin def __init__(self): self.queue = deque() def __next__(self): if not self.queue: self.queue.extend(self._line()) return self.queue.popleft() def wait(self): if not self.queue: self.queue.extend(self._line()) while self.queue: yield def _line(self): return TokenStream.stream.readline().split() def line(self): if self.queue: A = list(self.queue) self.queue.clear() return A return self._line() TokenStream.default = TokenStream() class CharStream(TokenStream): def _line(self): return TokenStream.stream.readline().rstrip() CharStream.default = CharStream() ParseFn = Callable[[TokenStream],_T] class Parser: def __init__(self, spec: Union[type[_T],_T]): self.parse = Parser.compile(spec) def __call__(self, ts: TokenStream) -> _T: return self.parse(ts) @staticmethod def compile_type(cls: type[_T], args = ()) -> _T: if issubclass(cls, Parsable): return cls.compile(*args) elif issubclass(cls, (Number, str)): def parse(ts: TokenStream): return cls(next(ts)) return parse elif issubclass(cls, tuple): return Parser.compile_tuple(cls, args) elif issubclass(cls, Collection): return Parser.compile_collection(cls, args) elif callable(cls): def parse(ts: TokenStream): return cls(next(ts)) return parse else: raise NotImplementedError() @staticmethod def compile(spec: Union[type[_T],_T]=int) -> ParseFn[_T]: if isinstance(spec, (type, GenericAlias)): cls = typing.get_origin(spec) or spec args = typing.get_args(spec) or tuple() return Parser.compile_type(cls, 
args) elif isinstance(offset := spec, Number): cls = type(spec) def parse(ts: TokenStream): return cls(next(ts)) + offset return parse elif isinstance(args := spec, tuple): return Parser.compile_tuple(type(spec), args) elif isinstance(args := spec, Collection): return Parser.compile_collection(type(spec), args) elif isinstance(fn := spec, Callable): def parse(ts: TokenStream): return fn(next(ts)) return parse else: raise NotImplementedError() @staticmethod def compile_line(cls: _T, spec=int) -> ParseFn[_T]: if spec is int: fn = Parser.compile(spec) def parse(ts: TokenStream): return cls([int(token) for token in ts.line()]) return parse else: fn = Parser.compile(spec) def parse(ts: TokenStream): return cls([fn(ts) for _ in ts.wait()]) return parse @staticmethod def compile_repeat(cls: _T, spec, N) -> ParseFn[_T]: fn = Parser.compile(spec) def parse(ts: TokenStream): return cls([fn(ts) for _ in range(N)]) return parse @staticmethod def compile_children(cls: _T, specs) -> ParseFn[_T]: fns = tuple((Parser.compile(spec) for spec in specs)) def parse(ts: TokenStream): return cls([fn(ts) for fn in fns]) return parse @staticmethod def compile_tuple(cls: type[_T], specs) -> ParseFn[_T]: if isinstance(specs, (tuple,list)) and len(specs) == 2 and specs[1] is ...: return Parser.compile_line(cls, specs[0]) else: return Parser.compile_children(cls, specs) @staticmethod def compile_collection(cls, specs): if not specs or len(specs) == 1 or isinstance(specs, set): return Parser.compile_line(cls, *specs) elif (isinstance(specs, (tuple,list)) and len(specs) == 2 and isinstance(specs[1], int)): return Parser.compile_repeat(cls, specs[0], specs[1]) else: raise NotImplementedError() class Parsable: @classmethod def compile(cls): def parser(ts: TokenStream): return cls(next(ts)) return parser import operator from itertools import accumulate from typing import Callable, Iterable, TypeVar def presum(iter: Iterable[_T], func: Callable[[_T,_T],_T] = None, initial: _T = None, step = 1) -> list[_T]: if step == 1: return list(accumulate(iter, func, initial=initial)) else: assert step >= 2 if func is None: func = operator.add A = list(iter) if initial is not None: A = [initial] + A for i in range(step,len(A)): A[i] = func(A[i], A[i-step]) return A def sort2(a, b): return (a,b) if a < b else (b,a) from itertools import pairwise from typing import Any, List class MinSparseTable: def __init__(self, arr: List[Any]): self.N = N = len(arr) self.log = N.bit_length() self.offsets = offsets = [0] for i in range(1, self.log): offsets.append(offsets[-1] + N - (1 << (i-1)) + 1) self.st = st = [0] * (offsets[-1] + N - (1 << (self.log-1)) + 1) st[:N] = arr for i,ni in pairwise(range(self.log)): start, nxt, d = offsets[i], offsets[ni], 1 << i for j in range(N - (1 << ni) + 1): st[nxt+j] = min(st[k := start+j], st[k + d]) def query(self, l: int, r: int) -> Any: k = (r-l).bit_length() - 1 start, st = self.offsets[k], self.st return min(st[start + l], st[start + r - (1 << k)]) def __repr__(self) -> str: rows, offsets, log, st = [], self.offsets, self.log, self.st for i in range(log): start = offsets[i] end = offsets[i+1] if i+1 < log else len(st) rows.append(f"{i:<2d} {st[start:end]}") return '\n'.join(rows) class LCATable(MinSparseTable): def __init__(lca, T, root = 0): N = len(T) T.euler_tour(root) lca.depth = depth = presum(T.delta) lca.tin, lca.tout = T.tin[:], T.tout[:] lca.mask = (1 << (shift := N.bit_length()))-1 lca.shift = shift order = T.order M = len(order) packets = [0]*M for i in range(M): packets[i] = depth[i] << shift | 
order[i] super().__init__(packets) def _query(lca, u, v): tin = lca.tin l, r = sort2(tin[u], tin[v]); r += 1 da = super().query(l, r) return l, r, da & lca.mask, da >> lca.shift def query(lca, u, v) -> tuple[int,int]: l, r, a, d = lca._query(u, v) return a, d def distance(lca, u, v) -> int: l, r, a, d = lca._query(u, v) return lca.depth[l] + lca.depth[r-1] - 2*d def path(lca, u, v): path, par, lca, c = [], lca.T.par, lca.query(u, v)[0], u while c != lca: path.append(c) c = par[c] path.append(lca) rev_path, c = [], v while c != lca: rev_path.append(c) c = par[c] path.extend(reversed(rev_path)) return path class LCATableWeighted(LCATable): def __init__(lca, T, root = 0): super().__init__(T, root) lca.weights = T.Wdelta lca.weighted_depth = None def distance(lca, u, v) -> int: if lca.weighted_depth is None: lca.weighted_depth = presum(lca.weights) l, r, a, _ = lca._query(u, v) m = lca.tin[a] return lca.weighted_depth[l] + lca.weighted_depth[r-1] - 2*lca.weighted_depth[m] def chmin(dp, i, v): if ch:=dp[i]>v:dp[i]=v return ch from typing import overload def pack_sm(N: int): s = N.bit_length() return s, (1<<s)-1 def pack_enc(a: int, b: int, s: int): return a << s | b def pack_dec(ab: int, s: int, m: int): return ab >> s, ab & m def pack_indices(A, s): return [a << s | i for i,a in enumerate(A)] def argsort(A: list[int], reverse=False): s, m = pack_sm(len(A)) if reverse: I = [a<<s|i^m for i,a in enumerate(A)] I.sort(reverse=True) for i,ai in enumerate(I): I[i] = (ai^m)&m else: I = [a<<s|i for i,a in enumerate(A)] I.sort() for i,ai in enumerate(I): I[i] = ai&m return I from math import inf from typing import Callable, Sequence, Union, overload from enum import auto, IntFlag, IntEnum class DFSFlags(IntFlag): ENTER = auto() DOWN = auto() BACK = auto() CROSS = auto() LEAVE = auto() UP = auto() MAXDEPTH = auto() RETURN_PARENTS = auto() RETURN_DEPTHS = auto() BACKTRACK = auto() CONNECT_ROOTS = auto() # Common combinations ALL_EDGES = DOWN | BACK | CROSS EULER_TOUR = DOWN | UP INTERVAL = ENTER | LEAVE TOPDOWN = DOWN | CONNECT_ROOTS BOTTOMUP = UP | CONNECT_ROOTS RETURN_ALL = RETURN_PARENTS | RETURN_DEPTHS class DFSEvent(IntEnum): ENTER = DFSFlags.ENTER DOWN = DFSFlags.DOWN BACK = DFSFlags.BACK CROSS = DFSFlags.CROSS LEAVE = DFSFlags.LEAVE UP = DFSFlags.UP MAXDEPTH = DFSFlags.MAXDEPTH class GraphBase(Sequence, Parsable): def __init__(G, N: int, M: int, U: list[int], V: list[int], deg: list[int], La: list[int], Ra: list[int], Ua: list[int], Va: list[int], Ea: list[int], twin: list[int] = None): G.N = N '''The number of vertices.''' G.M = M '''The number of edges.''' G.U = U '''A list of source vertices in the original edge list.''' G.V = V '''A list of destination vertices in the original edge list.''' G.deg = deg '''deg[u] is the out degree of vertex u.''' G.La = La '''La[u] stores the start index of the list of adjacent vertices from u.''' G.Ra = Ra '''Ra[u] stores the stop index of the list of adjacent vertices from u.''' G.Ua = Ua '''Ua[i] = u for La[u] <= i < Ra[u], useful for backtracking.''' G.Va = Va '''Va[i] lists adjacent vertices to u for La[u] <= i < Ra[u].''' G.Ea = Ea '''Ea[i] lists the edge ids that start from u for La[u] <= i < Ra[u]. For undirected graphs, edge ids in range M<= e <2*M are edges from V[e-M] -> U[e-M]. 
''' G.twin = twin if twin is not None else range(len(Ua)) '''twin[i] in undirected graphs stores index j of the same edge but with u and v swapped.''' G.st: list[int] = None G.order: list[int] = None G.vis: list[int] = None G.back: list[int] = None G.tin: list[int] = None def prep_vis(G): if G.vis is None: G.vis = u8f(G.N) return G.vis def prep_st(G): if G.st is None: G.st = elist(G.N) else: G.st.clear() return G.st def prep_order(G): if G.order is None: G.order = elist(G.N) else: G.order.clear() return G.order def prep_back(G): if G.back is None: G.back = i32f(G.N, -2) return G.back def prep_tin(G): if G.tin is None: G.tin = i32f(G.N, -1) return G.tin def __len__(G) -> int: return G.N def __getitem__(G, u): return G.Va[G.La[u]:G.Ra[u]] def range(G, u): return range(G.La[u],G.Ra[u]) @overload def distance(G) -> list[list[int]]: ... @overload def distance(G, s: int = 0) -> list[int]: ... @overload def distance(G, s: int, g: int) -> int: ... def distance(G, s = None, g = None): if s == None: return G.floyd_warshall() else: return G.bfs(s, g) def recover_path(G, s, t): Ua, back, vertices = G.Ua, G.back, u32f(1, v := t) while v != s: vertices.append(v := Ua[back[v]]) return vertices def recover_path_edge_ids(G, s, t): Ea, Ua, back, edges, v = G.Ea, G.Ua, G.back, u32f(0), t while v != s: edges.append(Ea[i := back[v]]), (v := Ua[i]) return edges def shortest_path(G, s: int, t: int): if G.distance(s, t) >= inf: return None vertices = G.recover_path(s, t) vertices.reverse() return vertices def shortest_path_edge_ids(G, s: int, t: int): if G.distance(s, t) >= inf: return None edges = G.recover_path_edge_ids(s, t) edges.reverse() return edges @overload def bfs(G, s: Union[int,list] = 0) -> list[int]: ... @overload def bfs(G, s: Union[int,list], g: int) -> int: ... def bfs(G, s: int = 0, g: int = None): S, Va, back, D = G.starts(s), G.Va, i32f(N := G.N, -1), [inf]*N G.back, G.D = back, D for u in S: D[u] = 0 que = deque(S) while que: nd = D[u := que.popleft()]+1 if u == g: return nd-1 for i in G.range(u): if nd < D[v := Va[i]]: D[v], back[v] = nd, i que.append(v) return D if g is None else inf def floyd_warshall(G) -> list[list[int]]: G.D = D = [[inf]*G.N for _ in range(G.N)] for u in range(G.N): D[u][u] = 0 for i in range(len(G.Ua)): D[G.Ua[i]][G.Va[i]] = 1 for k, Dk in enumerate(D): for Di in D: if (Dik := Di[k]) == inf: continue for j in range(G.N): chmin(Di, j, Dik+Dk[j]) return D def find_cycle_indices(G, s: Union[int, None] = None): Ea, Ua, Va, vis, back = G.Ea, G. 
Ua, G.Va, u8f(N := G.N), u32f(N, i32_max) G.vis, G.back, st = vis, back, elist(N) for s in G.starts(s): if vis[s]: continue st.append(s) while st: if not vis[u := st.pop()]: st.append(u) vis[u], pe = 1, Ea[j] if (j := back[u]) != i32_max else i32_max for i in G.range(u): if not vis[v := Va[i]]: back[v] = i st.append(v) elif vis[v] == 1 and pe != Ea[i]: I = u32f(1,i) while v != u: I.append(i := back[u]), (u := Ua[i]) I.reverse() return I else: vis[u] = 2 # check for self loops for i in range(len(Ua)): if Ua[i] == Va[i]: return u32f(1,i) def find_cycle(G, s: Union[int, None] = None): if I := G.find_cycle_indices(s): return [G.Ua[i] for i in I] def find_cycle_edge_ids(G, s: Union[int, None] = None): if I := G.find_cycle_indices(s): return [G.Ea[i] for i in I] def find_minimal_cycle(G, s=0): D, par, que, Va = u32f(N := G.N, u32_max), i32f(N, -1), deque([s]), G.Va D[s] = 0 while que: for i in G.range(u := que.popleft()): if (v := Va[i]) == s: # Found cycle back to start cycle = [u] while u != s: cycle.append(u := par[u]) return cycle if D[v] < u32_max: continue D[v], par[v] = D[u]+1, u; que.append(v) def dfs_topdown(G, s: Union[int,list] = None) -> list[int]: '''Returns lists of indices i where Ua[i] -> Va[i] are edges in order of top down discovery''' vis, st, order = G.prep_vis(), G.prep_st(), G.prep_order() for s in G.starts(s): if vis[s]: continue vis[s] = 1; st.append(s) while st: for i in G.range(st.pop()): if vis[v := G.Va[i]]: continue vis[v] = 1; order.append(i); st.append(v) return order def dfs(G, s: Union[int,list] = None, /, backtrack = False, max_depth = None, enter_fn: Callable[[int],None] = None, leave_fn: Callable[[int],None] = None, max_depth_fn: Callable[[int],None] = None, down_fn: Callable[[int,int,int],None] = None, back_fn: Callable[[int,int,int],None] = None, forward_fn: Callable[[int,int,int],None] = None, cross_fn: Callable[[int,int,int],None] = None, up_fn: Callable[[int,int,int],None] = None): I, time, vis, st, back, tin = G.La[:], -1, G.prep_vis(), G.prep_st(), G.prep_back(), G.prep_tin() for s in G.starts(s): if vis[s]: continue back[s], tin[s] = -1, (time := time+1); st.append(s) while st: if vis[u := st[-1]] == 0: vis[u] = 1 if enter_fn: enter_fn(u) if max_depth is not None and len(st) > max_depth: I[u] = G.Ra[u] if max_depth_fn: max_depth_fn(u) if (i := I[u]) < G.Ra[u]: I[u] += 1 if (s := vis[v := G.Va[i]]) == 0: back[v], tin[v] = i, (time := time+1); st.append(v) if down_fn: down_fn(u,v,i) elif back_fn and s == 1 and back[u] != G.twin[i]: back_fn(u,v,i) elif (cross_fn or forward_fn) and s == 2: if forward_fn and tin[u] < tin[v]: forward_fn(u,v,i) elif cross_fn: cross_fn(u,v,i) else: vis[u] = 2; st.pop() if backtrack: vis[u], I[u] = 0, G.La[u] if leave_fn: leave_fn(u) if up_fn and st: up_fn(u, st[-1], back[u]) def dfs_enter_leave(G, s: Union[int,list[int],None] = None) -> Sequence[tuple[DFSEvent,int]]: N, I = G.N, G.La[:] st, back, plst = elist(N), i32f(N,-2), PacketList(order := elist(2*N), N-1) G.back, ENTER, LEAVE = back, int(DFSEvent.ENTER) << plst.shift, int(DFSEvent.LEAVE) << plst.shift for s in G.starts(s): if back[s] >= -1: continue back[s] = -1 order.append(ENTER | s), st.append(s) while st: if (i := I[u := st[-1]]) < G.Ra[u]: I[u] += 1 if back[v := G.Va[i]] >= -1: continue back[v] = i; order.append(ENTER | v); st.append(v) else: order.append(LEAVE | u); st.pop() return plst def starts(G, s: Union[int,list[int],None]) -> list[int]: if isinstance(s, int): return [s] elif s is None: return range(G.N) elif isinstance(s, list): return s else: return list(s) 
@classmethod def compile(cls, N: int, M: int, shift: int = -1): def parse(ts: TokenStream): U, V = u32f(M), u32f(M) for i in range(M): u, v = ts._line() U[i], V[i] = int(u)+shift, int(v)+shift return cls(N, U, V) return parse def elist(est_len: int) -> list: ... try: from __pypy__ import newlist_hint except: def newlist_hint(hint): return [] elist = newlist_hint from array import array def i8f(N: int, elm: int = 0): return array('b', (elm,))*N # signed char def u8f(N: int, elm: int = 0): return array('B', (elm,))*N # unsigned char def i16f(N: int, elm: int = 0): return array('h', (elm,))*N # signed short def u16f(N: int, elm: int = 0): return array('H', (elm,))*N # unsigned short def i32f(N: int, elm: int = 0): return array('i', (elm,))*N # signed int def u32f(N: int, elm: int = 0): return array('I', (elm,))*N # unsigned int def i64f(N: int, elm: int = 0): return array('q', (elm,))*N # signed long long # def u64f(N: int, elm: int = 0): return array('Q', (elm,))*N # unsigned long long def f32f(N: int, elm: float = 0.0): return array('f', (elm,))*N # float def f64f(N: int, elm: float = 0.0): return array('d', (elm,))*N # double def i8a(init = None): return array('b') if init is None else array('b', init) # signed char def u8a(init = None): return array('B') if init is None else array('B', init) # unsigned char def i16a(init = None): return array('h') if init is None else array('h', init) # signed short def u16a(init = None): return array('H') if init is None else array('H', init) # unsigned short def i32a(init = None): return array('i') if init is None else array('i', init) # signed int def u32a(init = None): return array('I') if init is None else array('I', init) # unsigned int def i64a(init = None): return array('q') if init is None else array('q', init) # signed long long # def u64a(init = None): return array('Q') if init is None else array('Q', init) # unsigned long long def f32a(init = None): return array('f') if init is None else array('f', init) # float def f64a(init = None): return array('d') if init is None else array('d', init) # double i8_max = (1 << 7)-1 u8_max = (1 << 8)-1 i16_max = (1 << 15)-1 u16_max = (1 << 16)-1 i32_max = (1 << 31)-1 u32_max = (1 << 32)-1 i64_max = (1 << 63)-1 u64_max = (1 << 64)-1 class PacketList(Sequence[tuple[int,int]]): def __init__(lst, A: list[int], max1: int): lst.A = A lst.mask = (1 << (shift := (max1).bit_length())) - 1 lst.shift = shift def __len__(lst): return lst.A.__len__() def __contains__(lst, x: tuple[int,int]): return lst.A.__contains__(x[0] << lst.shift | x[1]) def __getitem__(lst, key) -> tuple[int,int]: x = lst.A[key] return x >> lst.shift, x & lst.mask class GraphWeightedBase(GraphBase): def __init__(self, N: int, M: int, U: list[int], V: list[int], W: list[int], deg: list[int], La: list[int], Ra: list[int], Ua: list[int], Va: list[int], Wa: list[int], Ea: list[int], twin: list[int] = None): super().__init__(N, M, U, V, deg, La, Ra, Ua, Va, Ea, twin) self.W = W self.Wa = Wa '''Wa[i] lists weights to edges from u for La[u] <= i < Ra[u].''' def __getitem__(G, u): l,r = G.La[u],G.Ra[u] return zip(G.Va[l:r], G.Wa[l:r]) @overload def distance(G) -> list[list[int]]: ... @overload def distance(G, s: int = 0) -> list[int]: ... @overload def distance(G, s: int, g: int) -> int: ... 
def distance(G, s = None, g = None): if s == None: return G.floyd_warshall() else: return G.dijkstra(s, g) def dijkstra(G, s: int, t: int = None): N, S, Va, Wa = G.N, G.starts(s), G.Va, G.Wa G.back, G.D = back, D = i32f(N, -1), [inf]*N for s in S: D[s] = 0 que = PriorityQueue(N, S) while que: u, d = que.pop() if d > D[u]: continue if u == t: return d for i in G.range(u): if chmin(D, v := Va[i], nd := d + Wa[i]): back[v] = i que.push(v, nd) return D if t is None else inf def kruskal(G): U, V, W, dsu, MST, need = G.U, G.V, G.W, DSU(N := G.N), [0]*(N-1), N-1 for e in argsort(W): u, v = dsu.merge(U[e],V[e],True) if u != v: MST[need := need-1] = e if not need: break return None if need else MST def kruskal_heap(G): N, M, U, V, W = G.N, G.M, G.U, G.V, G.W que = PriorityQueue(M, list(range(M)), W) dsu = DSU(N) MST = [0]*(N-1) need = N-1 while que and need: e, _ = que.pop() u, v = dsu.merge(U[e],V[e],True) if u != v: MST[need := need-1] = e return None if need else MST def bellman_ford(G, s: int = 0) -> list[int]: Ua, Va, Wa, D = G.Ua, G.Va, G.Wa, [inf]*(N := G.N) D[s] = 0 for _ in range(N-1): for i, u in enumerate(Ua): if D[u] < inf: chmin(D, Va[i], D[u] + Wa[i]) return D def bellman_ford_neg_cyc_check(G, s: int = 0) -> tuple[bool, list[int]]: M, U, V, W, D = G.M, G.U, G.V, G.W, G.bellman_ford(s) neg_cycle = any(D[U[i]]+W[i]<D[V[i]] for i in range(M) if D[U[i]] < inf) return neg_cycle, D def floyd_warshall(G) -> list[list[int]]: N, Ua, Va, Wa = G.N, G.Ua, G.Va, G.Wa D = [[inf]*N for _ in range(N)] for u in range(N): D[u][u] = 0 for i in range(len(Ua)): chmin(D[Ua[i]], Va[i], Wa[i]) for k, Dk in enumerate(D): for Di in D: if Di[k] >= inf: continue for j in range(N): if Dk[j] >= inf: continue chmin(Di, j, Di[k]+Dk[j]) return D def floyd_warshall_neg_cyc_check(G): D = G.floyd_warshall() return any(D[i][i] < 0 for i in range(G.N)), D @classmethod def compile(cls, N: int, M: int, shift: int = -1): def parse(ts: TokenStream): U, V, W = u32f(M), u32f(M), [0]*M for i in range(M): u, v, w = ts._line() U[i], V[i], W[i] = int(u)+shift, int(v)+shift, int(w) return cls(N, U, V, W) return parse class DSU: def __init__(self, N): self.N = N self.par = [-1] * N def merge(self, u, v, src = False): assert 0 <= u < self.N assert 0 <= v < self.N x, y = self.leader(u), self.leader(v) if x == y: return (x,y) if src else x if self.par[x] > self.par[y]: x, y = y, x self.par[x] += self.par[y] self.par[y] = x return (x,y) if src else x def same(self, u: int, v: int): assert 0 <= u < self.N assert 0 <= v < self.N return self.leader(u) == self.leader(v) def leader(self, i) -> int: assert 0 <= i < self.N par = self.par p = par[i] while p >= 0: if par[p] < 0: return p par[i], i, p = par[p], par[p], par[par[p]] return i def size(self, i) -> int: assert 0 <= i < self.N return -self.par[self.leader(i)] def groups(self) -> list[list[int]]: leader_buf = [self.leader(i) for i in range(self.N)] result = [[] for _ in range(self.N)] for i in range(self.N): result[leader_buf[i]].append(i) return [r for r in result if r] from collections import UserList from heapq import heapify, heappop, heappush, heappushpop, heapreplace from typing import Generic class HeapProtocol(Generic[_T]): def pop(self) -> _T: ... def push(self, item: _T): ... def pushpop(self, item: _T) -> _T: ... def replace(self, item: _T) -> _T: ... 
class PriorityQueue(HeapProtocol[int], UserList[int]): def __init__(self, N: int, ids: list[int] = None, priorities: list[int] = None, /): self.shift = N.bit_length() self.mask = (1 << self.shift)-1 if ids is None: self.data = elist(N) elif priorities is None: heapify(ids) self.data = ids else: M = len(ids) data = [0]*M for i in range(M): data[i] = self.encode(ids[i], priorities[i]) heapify(data) self.data = data def encode(self, id, priority): return priority << self.shift | id def decode(self, encoded): return self.mask & encoded, encoded >> self.shift def pop(self): return self.decode(heappop(self.data)) def push(self, id: int, priority: int): heappush(self.data, self.encode(id, priority)) def pushpop(self, id: int, priority: int): return self.decode(heappushpop(self.data, self.encode(id, priority))) def replace(self, id: int, priority: int): return self.decode(heapreplace(self.data, self.encode(id, priority))) class GraphWeighted(GraphWeightedBase): def __init__(G, N: int, U: list[int], V: list[int], W: list[int]): Ma, deg = 0, u32f(N) for e in range(M := len(U)): distinct = (u := U[e]) != (v := V[e]) deg[u] += 1; deg[v] += distinct; Ma += 1+distinct twin, Ea, Ua, Va, Wa = u32f(Ma), u32f(Ma), u32f(Ma), u32f(Ma), [0]*Ma La, i = u32f(N), 0 for u,d in enumerate(deg): La[u], i = i, i + d Ra = La[:] for e in range(M): u, v, w = U[e], V[e], W[e] i, j = Ra[u], Ra[v] Ra[u],Ua[i],Va[i],Wa[i],Ea[i],twin[i] = i+1,u,v,w,e,j if i == j: continue # don't add self loops twice Ra[v],Ua[j],Va[j],Wa[j],Ea[j],twin[j] = j+1,v,u,w,e,i super().__init__(N, M, U, V, W, deg, La, Ra, Ua, Va, Wa, Ea, twin) from typing import Optional from typing import Callable, Literal, TypeVar, Union, overload class TreeBase(GraphBase): @overload def distance(T) -> list[list[int]]: ... @overload def distance(T, s: int = 0) -> list[int]: ... @overload def distance(T, s: int, g: int) -> int: ... def distance(T, s = None, g = None): if s == None: return [T.dfs_distance(u) for u in range(T.N)] else: return T.dfs_distance(s, g) @overload def diameter(T) -> int: ... @overload def diameter(T, endpoints: Literal[True]) -> tuple[int,int,int]: ... 
def diameter(T, endpoints = False): mask = (1 << (shift := T.N.bit_length())) - 1 s = max(d << shift | v for v,d in enumerate(T.distance(0))) & mask dg = max(d << shift | v for v,d in enumerate(T.distance(s))) diam, g = dg >> shift, dg & mask return (diam, s, g) if endpoints else diam def dfs_distance(T, s: int, g: Union[int,None] = None): st, Va = elist(N := T.N), T.Va T.D, T.back = D, back = [inf]*N, i32f(N, -1) D[s] = 0 st.append(s) while st: nd = D[u := st.pop()]+1 if u == g: return nd-1 for i in T.range(u): if nd < D[v := Va[i]]: D[v], back[v] = nd, i st.append(v) return D if g is None else inf def rerooting_dp(T, e: _T, merge: Callable[[_T,_T],_T], edge_op: Callable[[_T,int,int,int],_T] = lambda s,i,p,u:s, s: int = 0): La, Ua, Va = T.La, T.Ua, T.Va order, dp, suf, I = T.dfs_topdown(s), [e]*T.N, [e]*len(Ua), T.Ra[:] # up for i in order[::-1]: u,v = Ua[i], Va[i] # subtree v finished up pass, store value to accumulate for u dp[v] = new = edge_op(dp[v], i, u, v) dp[u] = merge(dp[u], new) # suffix accumulation if (c:=I[u]-1) > La[u]: suf[c-1] = merge(suf[c], new) I[u] = c # down dp[s] = e # at this point dp stores values to be merged in parent for i in order: u,v = Ua[i], Va[i] dp[u] = merge(pre := dp[u], dp[v]) dp[v] = edge_op(merge(suf[I[u]], pre), i, v, u) I[u] += 1 return dp def euler_tour(T, s = 0): N, Va = len(T), T.Va tin, tout, par, back = [-1]*N,[-1]*N,[-1]*N,[0]*N order, delta = elist(2*N), elist(2*N) st = elist(N); st.append(s) while st: p = par[u := st.pop()] if tin[u] == -1: tin[u] = len(order) for i in T.range(u): if (v := Va[i]) != p: par[v], back[v] = u, i st.append(u); st.append(v) delta.append(1) else: delta.append(-1) order.append(u) tout[u] = len(order) delta[0] = delta[-1] = 0 T.tin, T.tout, T.par, T.back = tin, tout, par, back T.order, T.delta = order, delta def hld_precomp(T, r = 0): N, time, Va = T.N, 0, T.Va tin, tout, size = [0]*N, [0]*N, [1]*N+[0] par, heavy, head = [-1]*N, [-1]*N, [r]*N depth, order, vis = [0]*N, [0]*N, [0]*N st = elist(N) st.append(r) while st: if (s := vis[v := st.pop()]) == 0: # dfs down p, vis[v] = par[v], 1; st.append(v) for i in T.range(v): if (c := Va[i]) != p: depth[c], par[c] = depth[v]+1, v; st.append(c) elif s == 1: # dfs up p, l = par[v], -1 for i in T.range(v): if (c := Va[i]) != p: size[v] += size[c] if size[c] > size[l]: l = c heavy[v] = l if p == -1: vis[v] = 2 st.append(v) elif s == 2: # decompose down p, h, l = par[v], head[v], heavy[v] tin[v], order[time], vis[v] = time, v, 3 time += 1 st.append(v) for i in T.range(v): if (c := Va[i]) != p and c != l: head[c], vis[c] = c, 2 st.append(c) if l != -1: head[l], vis[l] = h, 2 st.append(l) elif s == 3: # decompose up tout[v] = time T.size, T.depth = size, depth T.order, T.tin, T.tout = order, tin, tout T.par, T.heavy, T.head = par, heavy, head @classmethod def compile(cls, N: int, shift: int = -1): return GraphBase.compile.__func__(cls, N, N-1, shift) class TreeWeightedBase(TreeBase, GraphWeightedBase): def dfs_distance(T, s: int, g: Optional[int] = None): st, Wa, Va = elist(N := T.N), T.Wa, T.Va T.D, T.back = D, back = [inf]*N, i32f(N, -1) D[s] = 0; st.append(s) while st: d = D[u := st.pop()] if u == g: return d for i in T.range(u): if (nd := d+Wa[i]) < D[v := Va[i]]: D[v], back[v] = nd, i; st.append(v) return D if g is None else inf def euler_tour(T, s = 0): N, Va, Wa = len(T), T.Va, T.Wa tin, tout, par = [-1]*N,[-1]*N,[-1]*N order, delta, Wdelta = elist(2*N), elist(2*N), elist(2*N) st, Wst = elist(N), elist(N) st.append(s); Wst.append(0) while st: p, wd = par[u := st.pop()], 
Wst.pop() if tin[u] == -1: tin[u] = len(order) for i in T.range(u): if (v := Va[i]) != p: w, par[v] = Wa[i], u st.append(u); st.append(v); Wst.append(-w); Wst.append(w) delta.append(1) else: delta.append(-1) Wdelta.append(wd); order.append(u) tout[u] = len(order) delta[0] = delta[-1] = 0 T.tin, T.tout, T.par = tin, tout, par T.order, T.delta, T.Wdelta = order, delta, Wdelta def hld_precomp(T, r = 0): N, time, Va, Wa = T.N, 0, T.Va, T.Wa tin, tout, size = [0]*N, [0]*N, [1]*N+[0] par, heavy, head = [-1]*N, [-1]*N, [r]*N depth, order, vis = [0]*N, [0]*N, [0]*N Wpar = [0]*N st = elist(N) st.append(r) while st: if (s := vis[v := st.pop()]) == 0: # dfs down p, vis[v] = par[v], 1 st.append(v) for i in T.range(v): if (c := Va[i]) != p: depth[c], par[c], Wpar[c] = depth[v]+1, v, Wa[i] st.append(c) elif s == 1: # dfs up p, l = par[v], -1 for i in T.range(v): if (c := Va[i]) != p: size[v] += size[c] if size[c] > size[l]: l = c heavy[v] = l if p == -1: vis[v] = 2 st.append(v) elif s == 2: # decompose down p, h, l = par[v], head[v], heavy[v] tin[v], order[time], vis[v] = time, v, 3 time += 1 st.append(v) for i in T.range(v): if (c := Va[i]) != p and c != l: head[c], vis[c] = c, 2 st.append(c) if l != -1: head[l], vis[l] = h, 2 st.append(l) elif s == 3: # decompose up tout[v] = time T.size, T.depth = size, depth T.order, T.tin, T.tout = order, tin, tout T.par, T.heavy, T.head = par, heavy, head T.Wpar = Wpar @classmethod def compile(cls, N: int, shift: int = -1): return GraphWeightedBase.compile.__func__(cls, N, N-1, shift) class TreeWeighted(TreeWeightedBase, GraphWeighted): pass class AuxTreeBase(TreeWeightedBase): def __init__(T, lca: LCATable): T.lca = lca T.Vset = elist(T.N) T.post = elist(T.N-1) T.Ra = T.La[:] def add(T, u, v): w = T.lca.distance(u,v) i, j = T.Ra[u], T.Ra[v] T.Ua[i], T.Va[i], T.Wa[i], T.twin[i], T.Ra[u] = u, v, w, j, i+1 if i == j: return j T.Ua[j], T.Va[j], T.Wa[j], T.twin[j], T.Ra[v] = v, u, w, i, j+1 return j def tree(T, U: list[int], sort=True): if sort: U = sorted(U, key = T.tin.__getitem__) st = T.prep_st() lca, tin, V, post = T.lca, T.tin, T.Vset, T.post # reset while V: T.Ra[u] = T.La[u := V.pop()] if T.vis: T.vis[u] = 0 post.clear() st.append(U[0]) for j in range(len(U)-1): u, v = U[j], U[j+1] a, _ = lca.query(u, v) if a != u: l = st.pop() while st and tin[t := st[-1]] > tin[a]: V.append(l); post.append(T.add(l, l := st.pop())) if not st or t != a: st.append(a) V.append(l); post.append(T.add(l, a)) st.append(v) l = st.pop() while st: V.append(l); post.append(T.add(l, l := st.pop())) V.append(l) return V, post def trees(T, C: list[int]): lca, N = T.lca, T.N T.Ra, cnt, order = T.La[:], [0]*N, argsort(T.tin) for c in C: cnt[c] += 1 L = [0]*N for i in range(N-1): L[i+1] = L[i]+cnt[i] R, G = L[:], [0]*N for i in order: c = C[i]; G[R[c]] = i; R[c] += 1 st, V, post = elist(N), elist(N), elist(N) La, Ra, tin = T.La, T.Ra, T.tin for c in range(N): l, r = L[c], R[c] if l == r: continue st.append(G[l]) for j in range(l,r-1): u, v = G[j], G[j+1] a, _ = lca.query(u, v) if a != u: l = st.pop() while st and tin[t := st[-1]] > tin[a]: V.append(l); post.append(T.add(l, l := st.pop())) if not st or t != a: st.append(a) V.append(l); post.append(T.add(l, a)) st.append(v) l = st.pop() while st: V.append(l); post.append(T.add(l, l := st.pop())) V.append(l) yield c, V, post while V: Ra[u] = La[u := V.pop()] if T.vis: T.vis[u] = 0 post.clear() def rerooting_dp(T, C: list[int], e: _T, merge: Callable[[_T,_T],_T], edge_op: Callable[[_T,int,int,int,int],_T] = lambda s,i,p,u,c:s): ans, dp, suf, I = 
[e]*T.N, [e]*T.N, [e]*len(T.Ua), T.La[:] for c, V, post in T.trees(C): r = V[-1] for v in V: I[v] = T.Ra[v] # up for i in post: u,v = T.Ua[i], T.Va[i] # subtree v finished up pass, store value to accumulate for u dp[v] = new = edge_op(dp[v], i, u, v, c) dp[u] = merge(dp[u], new) # suffix accumulation if (j:=I[u]-1) > T.La[u]: suf[j-1] = merge(suf[j], new) I[u] = j # down dp[r] = e # at this point dp stores values to be merged in parent for i in reversed(post): u,v = T.Ua[i], T.Va[i] dp[u] = merge(pre := dp[u], dp[v]) dp[v] = edge_op(merge(suf[I[u]], pre), i, v, u, c) I[u] += 1 # store ans and reset for v in V: if C[v] == c: ans[v] = dp[v] dp[v] = e for i in post: suf[i] = e return ans class AuxTreeWeighted(AuxTreeBase, TreeWeighted): def __init__(T, N, U, V, W, root=0): TreeWeighted.__init__(T, N, U, V, W) AuxTreeBase.__init__(T, LCATableWeighted(T, root)) @classmethod def compile(cls, N: int, shift: int = -1, root=0): M = N-1 def parse(ts: TokenStream): U, V, W = u32f(M), u32f(M), [0]*M for i in range(M): u, v, w = ts._line() U[i], V[i], W[i] = int(u)+shift, int(v)+shift, int(w) return cls(N, U, V, W, root) return parse from typing import Iterable, Type, Union, overload @overload def read() -> Iterable[int]: ... @overload def read(spec: int) -> list[int]: ... @overload def read(spec: Union[Type[_T],_T], char=False) -> _T: ... def read(spec: Union[Type[_T],_T] = None, char=False): if not char and spec is None: return map(int, TokenStream.default.line()) parser: _T = Parser.compile(spec) return parser(CharStream.default if char else TokenStream.default) def write(*args, **kwargs): '''Prints the values to a stream, or to stdout_fast by default.''' sep, file = kwargs.pop("sep", " "), kwargs.pop("file", IOWrapper.stdout) at_start = True for x in args: if not at_start: file.write(sep) file.write(str(x)) at_start = False file.write(kwargs.pop("end", "\n")) if kwargs.pop("flush", False): file.flush() if __name__ == '__main__': main()
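For each query, `main()` builds the weighted auxiliary (virtual) tree of the k given vertices with `T.tree(X)` and sums the compressed edge weights `T.Wa[i]` over the edges listed in `post`; since each compressed edge carries the tree distance between its endpoints, that total is the weight of the minimal subtree connecting the query vertices. The same answer can also be obtained without materializing the auxiliary tree, via the standard identity that this weight equals half the cyclic sum of distances between consecutive query vertices sorted by DFS preorder (Euler-tour entry time). Below is a minimal, self-contained sketch of that identity using a binary-lifting LCA; the `solve()` function and its I/O handling are illustrative assumptions, not the submitted code, and it assumes the problem's 0-indexed input format.

```python
# Illustrative sketch (not the submitted code): Steiner-subtree weight on a tree
# as half the cyclic sum of pairwise distances in DFS-preorder order.
import sys

def solve():
    data = sys.stdin.buffer.read().split()
    pos = 0
    N = int(data[pos]); pos += 1
    adj = [[] for _ in range(N)]
    for _ in range(N - 1):
        u, v, w = int(data[pos]), int(data[pos + 1]), int(data[pos + 2]); pos += 3
        adj[u].append((v, w))
        adj[v].append((u, w))

    LOG = max(1, (N - 1).bit_length())
    par = [[-1] * N for _ in range(LOG)]   # binary-lifting ancestor table
    depth = [0] * N                        # depth in edges (for LCA lifting)
    dist = [0] * N                         # weighted distance from the root
    tin = [0] * N                          # DFS preorder (Euler-tour entry) index
    timer = 0
    visited = [False] * N
    visited[0] = True
    stack = [0]
    while stack:                           # iterative DFS from vertex 0
        u = stack.pop()
        tin[u] = timer; timer += 1
        for v, w in adj[u]:
            if not visited[v]:
                visited[v] = True
                par[0][v] = u
                depth[v] = depth[u] + 1
                dist[v] = dist[u] + w
                stack.append(v)
    for k in range(1, LOG):
        prev, cur = par[k - 1], par[k]
        for v in range(N):
            p = prev[v]
            cur[v] = -1 if p < 0 else prev[p]

    def lca(u, v):
        if depth[u] < depth[v]:
            u, v = v, u
        d = depth[u] - depth[v]
        for k in range(LOG):
            if (d >> k) & 1:
                u = par[k][u]
        if u == v:
            return u
        for k in reversed(range(LOG)):
            if par[k][u] != par[k][v]:
                u, v = par[k][u], par[k][v]
        return par[0][u]

    def tree_dist(u, v):
        return dist[u] + dist[v] - 2 * dist[lca(u, v)]

    Q = int(data[pos]); pos += 1
    out = []
    for _ in range(Q):
        k = int(data[pos]); pos += 1
        X = [int(x) for x in data[pos:pos + k]]; pos += k
        X.sort(key=lambda v: tin[v])       # sort by DFS preorder
        cyc = sum(tree_dist(X[i], X[(i + 1) % k]) for i in range(k))
        out.append(cyc // 2)               # each Steiner-subtree edge is counted twice
    sys.stdout.write('\n'.join(map(str, out)) + '\n')

if __name__ == '__main__':
    solve()
```

Sorting by Euler-tour entry time makes consecutive query vertices adjacent along the DFS walk, so every edge of the connecting subtree is traversed exactly twice in the cyclic tour; hence the division by two.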