結果
| 問題 | No.901 K-ary εxtrεεmε | 
| コンテスト | |
| ユーザー |  | 
| 提出日時 | 2025-03-26 21:14:02 | 
| 言語 | PyPy3 (7.3.15) | 
| 結果 | AC |
| 実行時間 | 704 ms / 3,000 ms | 
| コード長 | 45,090 bytes | 
| コンパイル時間 | 479 ms | 
| コンパイル使用メモリ | 86,528 KB | 
| 実行使用メモリ | 147,912 KB | 
| 最終ジャッジ日時 | 2025-03-26 21:14:22 | 
| 合計ジャッジ時間 | 18,534 ms | 
| ジャッジサーバーID (参考情報) | judge3 / judge1 | 
(要ログイン)
| ファイルパターン | 結果 | 
|---|---|
| sample | AC * 1 | 
| other | AC * 29 | 
ソースコード
# verification-helper: PROBLEM https://yukicoder.me/problems/3407
def main():
    """Solve yukicoder No.901: answer Q weighted auxiliary-tree queries.

    For each query line ``k x_1 ... x_k`` the answer is the sum of ``T.Wa`` over
    the indices returned by ``T.tree(X)``. Presumably this is the total edge
    weight of the minimal subtree spanning X (TODO confirm against
    AuxTreeWeighted, which is defined outside this view).
    """
    N = read(int)
    T = read(AuxTreeWeighted[N,0])
    for _ in range(read(int)):
        _k, *X = read()
        _V, post = T.tree(X)
        write(sum(T.Wa[i] for i in post))
    
'''
╺━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸
             https://kobejean.github.io/cp-library               
'''
import typing
from collections import deque
from numbers import Number
from types import GenericAlias 
from typing import Callable, Collection, Iterator, Union
import os
import sys
from io import BytesIO, IOBase
class FastIO(IOBase):
    """Byte-level buffered I/O performed directly on a file descriptor.

    Input is pulled from the fd with ``os.read`` and appended behind the unread
    data of an in-memory ``BytesIO``. Output is collected in the same ``BytesIO``
    and only written to the fd by ``flush()``. A file counts as writable when
    its mode contains "x" or does not contain "r".
    """
    BUFSIZE = 8192
    newlines = 0
    def __init__(self, file):
        self._fd = file.fileno()
        self.buffer = BytesIO()
        mode = file.mode
        self.writable = "x" in mode or "r" not in mode
        if self.writable:
            self.write = self.buffer.write
        else:
            self.write = None

    def _pull(self):
        """Read one chunk from the fd and append it behind the buffered data.

        The buffer's read position is left unchanged. Returns the chunk, which
        is ``b""`` at end of input.
        """
        fd = self._fd
        chunk = os.read(fd, max(os.fstat(fd).st_size, self.BUFSIZE))
        buf = self.buffer
        pos = buf.tell()
        buf.seek(0, 2)
        buf.write(chunk)
        buf.seek(pos)
        return chunk

    def read(self):
        """Drain the fd to end of input and return everything not yet consumed."""
        while self._pull():
            pass
        self.newlines = 0
        return self.buffer.read()

    def readline(self):
        """Return the next line (with its newline), pulling more input as needed."""
        while not self.newlines:
            chunk = self._pull()
            # An empty chunk (end of input) counts as one line so the loop ends.
            self.newlines = chunk.count(b"\n") + (not chunk)
        self.newlines -= 1
        return self.buffer.readline()

    def flush(self):
        """Write all buffered output to the fd and reset the buffer (writable files only)."""
        if not self.writable:
            return
        os.write(self._fd, self.buffer.getvalue())
        self.buffer.truncate(0)
        self.buffer.seek(0)
class IOWrapper(IOBase):
    """ASCII text facade over FastIO.

    The ``stdin``/``stdout`` class attributes hold the shared process-wide
    instances once they have been installed.
    """
    stdin: 'IOWrapper' = None
    stdout: 'IOWrapper' = None
    
    def __init__(self, file):
        fast = FastIO(file)
        self.buffer = fast
        self.flush = fast.flush
        self.writable = fast.writable
    def write(self, s):
        """Encode ``s`` as ASCII and append it to the output buffer."""
        data = s.encode("ascii")
        return self.buffer.write(data)
    
    def read(self):
        """Return all remaining input decoded as ASCII."""
        data = self.buffer.read()
        return data.decode("ascii")
    
    def readline(self):
        """Return the next input line decoded as ASCII."""
        data = self.buffer.readline()
        return data.decode("ascii")
# Replace the standard streams with FastIO-backed wrappers. The instances are
# also recorded on IOWrapper so other code (e.g. TokenStream.stream) can find them.
IOWrapper.stdin = IOWrapper(sys.stdin)
sys.stdin = IOWrapper.stdin
IOWrapper.stdout = IOWrapper(sys.stdout)
sys.stdout = IOWrapper.stdout
from typing import TypeVar
# Generic type variable used in the annotations below. Its name must match the
# variable it is bound to; type checkers reject TypeVar('T') assigned to _T.
_T = TypeVar('_T')
class TokenStream(Iterator):
    """Iterator over the whitespace-separated tokens of ``TokenStream.stream``.

    Input is read one line at a time into a FIFO queue of pending tokens.
    """
    stream = IOWrapper.stdin
    def __init__(self):
        self.queue = deque()
    def __next__(self):
        pending = self.queue
        if not pending:
            pending.extend(self._line())
        return pending.popleft()
    
    def wait(self):
        """Yield once per step while tokens of the current line remain queued.

        If the queue is empty, a new line is read first. The loop only ends when
        the queue is empty, so each consumer step must pop a token.
        """
        pending = self.queue
        if not pending:
            pending.extend(self._line())
        while pending:
            yield
 
    def _line(self):
        """Read one line from the shared stream and split it into tokens."""
        text = TokenStream.stream.readline()
        return text.split()
    def line(self):
        """Return the rest of the current line: queued tokens if any, else a fresh line."""
        pending = self.queue
        if not pending:
            return self._line()
        rest = list(pending)
        pending.clear()
        return rest
TokenStream.default = TokenStream()
class CharStream(TokenStream):
    """TokenStream whose tokens are the characters of each line (trailing whitespace stripped)."""
    def _line(self):
        raw = TokenStream.stream.readline()
        return raw.rstrip()
CharStream.default = CharStream()
ParseFn = Callable[[TokenStream],_T]
class Parser:
    """Turn a *spec* into a function that reads one value from a TokenStream.

    ``compile`` dispatches on the spec in this order:
      * a class or ``types.GenericAlias`` (``int``, ``list[int]``,
        ``tuple[int, str]``, ``list[int, 5]``, a Parsable subclass, ...)
        -> ``compile_type``;
      * a number instance ``k`` -> parse a token as ``type(k)`` and add ``k``
        (e.g. ``-1`` turns 1-based input into 0-based);
      * a tuple instance -> one value per element spec;
      * any other Collection instance -> ``compile_collection``;
      * any other callable ``fn`` -> ``fn(next_token)``.
    """
    def __init__(self, spec: Union[type[_T],_T]):
        self.parse = Parser.compile(spec)
    def __call__(self, ts: TokenStream) -> _T:
        return self.parse(ts)

    @staticmethod
    def compile_type(cls: type[_T], args = ()) -> _T:
        """Build a parser for the class ``cls`` given its type arguments ``args``."""
        if issubclass(cls, Parsable):
            return cls.compile(*args)
        elif issubclass(cls, (Number, str)):
            def parse(ts: TokenStream): return cls(next(ts))
            return parse
        elif issubclass(cls, tuple):
            return Parser.compile_tuple(cls, args)
        elif issubclass(cls, Collection):
            return Parser.compile_collection(cls, args)
        elif callable(cls):
            def parse(ts: TokenStream):
                return cls(next(ts))
            return parse
        else:
            raise NotImplementedError()

    @staticmethod
    def compile(spec: Union[type[_T],_T]=int) -> ParseFn[_T]:
        """Compile ``spec`` (see the class docstring) into a parse function."""
        if isinstance(spec, (type, GenericAlias)):
            cls = typing.get_origin(spec) or spec
            args = typing.get_args(spec) or tuple()
            return Parser.compile_type(cls, args)
        elif isinstance(offset := spec, Number):
            cls = type(spec)
            def parse(ts: TokenStream): return cls(next(ts)) + offset
            return parse
        elif isinstance(args := spec, tuple):
            return Parser.compile_tuple(type(spec), args)
        elif isinstance(args := spec, Collection):
            return Parser.compile_collection(type(spec), args)
        elif isinstance(fn := spec, Callable):
            def parse(ts: TokenStream): return fn(next(ts))
            return parse
        else:
            raise NotImplementedError()
    @staticmethod
    def compile_line(cls: _T, spec=int) -> ParseFn[_T]:
        """Parse every remaining token of the current line into ``cls([...])``."""
        if spec is int:
            # Fast path: convert the raw tokens directly, no per-token parser needed.
            def parse(ts: TokenStream): return cls([int(token) for token in ts.line()])
            return parse
        fn = Parser.compile(spec)
        # ts.wait() keeps yielding while the line still has queued tokens; fn pops them.
        def parse(ts: TokenStream): return cls([fn(ts) for _ in ts.wait()])
        return parse
    @staticmethod
    def compile_repeat(cls: _T, spec, N) -> ParseFn[_T]:
        """Parse exactly ``N`` values of ``spec`` into ``cls([...])``."""
        fn = Parser.compile(spec)
        def parse(ts: TokenStream): return cls([fn(ts) for _ in range(N)])
        return parse
    @staticmethod
    def compile_children(cls: _T, specs) -> ParseFn[_T]:
        """Parse one value per spec in ``specs``, in order, into ``cls([...])``."""
        fns = tuple((Parser.compile(spec) for spec in specs))
        def parse(ts: TokenStream): return cls([fn(ts) for fn in fns])
        return parse

    @staticmethod
    def compile_tuple(cls: type[_T], specs) -> ParseFn[_T]:
        """``(X, ...)`` means a whole line of X; otherwise one value per element spec."""
        if isinstance(specs, (tuple,list)) and len(specs) == 2 and specs[1] is ...:
            return Parser.compile_line(cls, specs[0])
        else:
            return Parser.compile_children(cls, specs)
    @staticmethod
    def compile_collection(cls, specs):
        """No spec or one spec -> a whole line; ``(spec, N)`` -> exactly N values.

        Raises:
            NotImplementedError: for any other spec shape, including a set with
                more than one element (a set has no order, so only one element
                spec is meaningful).
        """
        if not specs or len(specs) == 1:
            return Parser.compile_line(cls, *specs)
        if isinstance(specs, set):
            raise NotImplementedError()
        if isinstance(specs, (tuple,list)) and len(specs) == 2 and isinstance(specs[1], int):
            return Parser.compile_repeat(cls, specs[0], specs[1])
        raise NotImplementedError()
class Parsable:
    """Mixin for types that Parser can read.

    The default parser builds an instance of the class from the next single token.
    """
    @classmethod
    def compile(cls):
        def parser(ts: TokenStream):
            token = next(ts)
            return cls(token)
        return parser
import operator
from itertools import accumulate
from typing import Callable, Iterable, TypeVar
def presum(iter: Iterable[_T], func: Callable[[_T,_T],_T] = None, initial: _T = None, step = 1) -> list[_T]:
    """Return the running fold of ``iter``, like ``itertools.accumulate``.

    With ``step == 1`` this is exactly ``list(accumulate(iter, func, initial=initial))``.
    With ``step >= 2`` element ``i`` is combined with element ``i - step``. This
    gives an independent prefix fold for each residue class modulo ``step``.
    If ``initial`` is given it is prepended first, as ``accumulate`` does.

    In both modes ``func`` is called as ``func(earlier_result, later_element)``,
    matching ``accumulate``. ``func`` defaults to ``operator.add``.
    """
    if step == 1:
        return list(accumulate(iter, func, initial=initial))
    assert step >= 2
    if func is None:
        func = operator.add
    A = list(iter)
    if initial is not None:
        A = [initial] + A
    for i in range(step, len(A)):
        # Earlier partial result first, so non-commutative funcs agree with step == 1.
        A[i] = func(A[i-step], A[i])
    return A
def sort2(a, b):
    """Return ``(a, b)`` in ascending order; when they are not ``a < b``, returns ``(b, a)``."""
    if a < b:
        return a, b
    return b, a
from itertools import pairwise
from typing import Any, List
class MinSparseTable:
    def __init__(self, arr: List[Any]):
        self.N = N = len(arr)
        self.log = N.bit_length()
        
        self.offsets = offsets = [0]
        for i in range(1, self.log):
            offsets.append(offsets[-1] + N - (1 << (i-1)) + 1)
            
        self.st = st = [0] * (offsets[-1] + N - (1 << (self.log-1)) + 1)
        st[:N] = arr 
        
        for i,ni in pairwise(range(self.log)):
            start, nxt, d = offsets[i], offsets[ni], 1 << i
            for j in range(N - (1 << ni) + 1):
                st[nxt+j] = min(st[k := start+j], st[k + d])
    def query(self, l: int, r: int) -> Any:
        k = (r-l).bit_length() - 1
        start, st = self.offsets[k], self.st
        return min(st[start + l], st[start + r - (1 << k)])
    
    def __repr__(self) -> str:
        rows, offsets, log, st = [], self.offsets, self.log, self.st
        for i in range(log):
            start = offsets[i]
            end = offsets[i+1] if i+1 < log else len(st)
            rows.append(f"{i:<2d} {st[start:end]}")
        return '\n'.join(rows)
class LCATable(MinSparseTable):
    '''Lowest-common-ancestor queries answered by a range minimum over an Euler tour.

    Tour position i is stored as the packet ``depth[i] << shift | order[i]``.
    The minimum packet between the tour positions of u and v therefore holds
    the shallowest vertex on that stretch (their LCA) in the low bits and its
    depth in the high bits.
    '''
    def __init__(lca, T, root = 0):
        N = len(T)
        # NOTE(review): assumes T.euler_tour(root) fills T.order (vertex at each
        # tour position), T.delta (depth change per position, summed below),
        # T.tin (first tour position per vertex) and T.tout — confirm.
        T.euler_tour(root)
        # path() reads lca.T.par, so the tree must be kept on the table.
        lca.T = T
        lca.depth = depth = presum(T.delta)
        lca.tin, lca.tout = T.tin[:], T.tout[:]
        lca.mask = (1 << (shift := N.bit_length()))-1
        lca.shift = shift
        order = T.order
        M = len(order)
        packets = [0]*M
        for i in range(M):
            packets[i] = depth[i] << shift | order[i]
        super().__init__(packets)
    def _query(lca, u, v):
        '''Return (l, r, a, d): [l, r) spans the tour positions of u and v,
        a is their LCA and d its depth.'''
        tin = lca.tin
        l, r = sort2(tin[u], tin[v]); r += 1
        da = super().query(l, r)
        return l, r, da & lca.mask, da >> lca.shift
    def query(lca, u, v) -> tuple[int,int]:
        '''Return (lca vertex, depth of the lca).'''
        l, r, a, d = lca._query(u, v)
        return a, d

    def distance(lca, u, v) -> int:
        '''Number of edges between u and v: depth(u) + depth(v) - 2 * depth(lca).

        Uses depth at the tour positions tin[u] and tin[v]. This assumes the
        tour records each vertex at its own depth at position tin[x].
        '''
        l, r, a, d = lca._query(u, v)
        return lca.depth[l] + lca.depth[r-1] - 2*d

    def path(lca, u, v):
        '''Vertices on the path from u to v, u first and v last.

        NOTE(review): relies on the tree exposing ``par`` (parent of each
        vertex); that attribute is not visible here — confirm the tree sets it.
        '''
        par = lca.T.par
        a = lca.query(u, v)[0]
        path, c = [], u
        while c != a:
            path.append(c)
            c = par[c]
        path.append(a)
        rev_path, c = [], v
        while c != a:
            rev_path.append(c)
            c = par[c]
        path.extend(reversed(rev_path))
        return path
class LCATableWeighted(LCATable):
    '''LCATable whose distance() sums tour weights (T.Wdelta) instead of counting edges.'''
    def __init__(lca, T, root = 0):
        super().__init__(T, root)
        lca.weights = T.Wdelta
        # Prefix sums of the weights, built lazily on the first distance() call.
        lca.weighted_depth = None
    def distance(lca, u, v) -> int:
        wd = lca.weighted_depth
        if wd is None:
            wd = lca.weighted_depth = presum(lca.weights)
        l, r, a, _ = lca._query(u, v)
        # weighted depth(u) + weighted depth(v) - 2 * weighted depth(lca)
        return wd[l] + wd[r-1] - 2*wd[lca.tin[a]]
def chmin(dp, i, v):
    """Lower ``dp[i]`` to ``v`` if ``v`` is smaller; return the result of ``dp[i] > v``."""
    improved = dp[i] > v
    if improved:
        dp[i] = v
    return improved
from typing import overload
def pack_sm(N: int):
    """Return ``(shift, mask)``: ``shift = N.bit_length()`` and mask has that many low bits set."""
    shift = N.bit_length()
    mask = (1 << shift) - 1
    return shift, mask
def pack_enc(a: int, b: int, s: int):
    """Pack ``a`` into the bits above position ``s`` and ``b`` into the low bits."""
    high = a << s
    return high | b
    
def pack_dec(ab: int, s: int, m: int):
    """Inverse of pack_enc: return ``(ab >> s, ab & m)``."""
    high = ab >> s
    low = ab & m
    return high, low
def pack_indices(A, s):
    """Pack each value with its index as ``value << s | index`` (indices must fit in ``s`` bits)."""
    packed = []
    for idx, val in enumerate(A):
        packed.append(val << s | idx)
    return packed
def argsort(A: list[int], reverse=False):
    """Return the indices that sort the int list ``A``, ascending or (``reverse``) descending.

    Each value is packed with its index into one int and the packed ints are
    sorted. Ties keep ascending index order in both directions, because the
    reverse case stores the complemented index.
    """
    s, m = pack_sm(len(A))
    if not reverse:
        keys = [a<<s|i for i,a in enumerate(A)]
        keys.sort()
        return [k & m for k in keys]
    keys = [a<<s|i^m for i,a in enumerate(A)]
    keys.sort(reverse=True)
    return [(k ^ m) & m for k in keys]
from math import inf
from typing import Callable, Sequence, Union, overload
from enum import auto, IntFlag, IntEnum
class DFSFlags(IntFlag):
    """Bit flags naming DFS events and options; combine them with ``|``.

    Each base flag is a distinct power of two. They are written explicitly
    here and equal the values ``auto()`` assigns in an IntFlag.
    """
    ENTER = 1 << 0
    DOWN = 1 << 1
    BACK = 1 << 2
    CROSS = 1 << 3
    LEAVE = 1 << 4
    UP = 1 << 5
    MAXDEPTH = 1 << 6
    RETURN_PARENTS = 1 << 7
    RETURN_DEPTHS = 1 << 8
    BACKTRACK = 1 << 9
    CONNECT_ROOTS = 1 << 10
    # Common combinations
    ALL_EDGES = DOWN | BACK | CROSS
    EULER_TOUR = DOWN | UP
    INTERVAL = ENTER | LEAVE
    TOPDOWN = DOWN | CONNECT_ROOTS
    BOTTOMUP = UP | CONNECT_ROOTS
    RETURN_ALL = RETURN_PARENTS | RETURN_DEPTHS
class DFSEvent(IntEnum):
    # Tags for individual DFS events. Each member takes its value from the
    # DFSFlags bit of the same name, so an event compares equal to that flag.
    # GraphBase.dfs_enter_leave packs ENTER/LEAVE into (event, vertex) packets.
    ENTER = DFSFlags.ENTER 
    DOWN = DFSFlags.DOWN 
    BACK = DFSFlags.BACK 
    CROSS = DFSFlags.CROSS 
    LEAVE = DFSFlags.LEAVE 
    UP = DFSFlags.UP 
    MAXDEPTH = DFSFlags.MAXDEPTH
    
class GraphBase(Sequence, Parsable):
    '''Common storage and algorithms for graphs kept as flat adjacency arrays.

    Subclasses compute the arrays and pass them to __init__. The out-edges of
    vertex u are the adjacency indices La[u] <= i < Ra[u]; entry i goes from
    Ua[i] to Va[i] and belongs to input edge Ea[i]. Search methods store their
    scratch/result arrays on the instance (st, order, vis, back, tin, D) so
    they can be reused or inspected after the call.
    '''
    def __init__(G, N: int, M: int, U: list[int], V: list[int], 
                 deg: list[int], La: list[int], Ra: list[int],
                 Ua: list[int], Va: list[int], Ea: list[int], twin: list[int] = None):
        G.N = N
        '''The number of vertices.'''
        G.M = M
        '''The number of edges.'''
        G.U = U
        '''A list of source vertices in the original edge list.'''
        G.V = V
        '''A list of destination vertices in the original edge list.'''
        G.deg = deg
        '''deg[u] is the out degree of vertex u.'''
        G.La = La
        '''La[u] stores the start index of the list of adjacent vertices from u.'''
        G.Ra = Ra
        '''Ra[u] stores the stop index of the list of adjacent vertices from u.'''
        G.Ua = Ua
        '''Ua[i] = u for La[u] <= i < Ra[u], useful for backtracking.'''
        G.Va = Va
        '''Va[i] lists adjacent vertices to u for La[u] <= i < Ra[u].'''
        G.Ea = Ea
        '''Ea[i] lists the edge ids that start from u for La[u] <= i < Ra[u].
        For undirected graphs, edge ids in range M<= e <2*M are edges from V[e-M] -> U[e-M].
        '''
        G.twin = twin if twin is not None else range(len(Ua))
        '''twin[i] in undirected graphs stores index j of the same edge but with u and v swapped.'''
        # Scratch buffers, allocated lazily by the prep_* helpers below.
        G.st: list[int] = None
        G.order: list[int] = None
        G.vis: list[int] = None
        G.back: list[int] = None
        G.tin: list[int] = None
    # prep_*: allocate a scratch buffer on first use and cache it on G.
    # prep_st / prep_order clear an existing buffer; prep_vis, prep_back and
    # prep_tin return the existing one unchanged.
    # NOTE(review): because vis is never cleared here, a second dfs() or
    # dfs_topdown() call on the same graph treats every vertex marked by an
    # earlier call as already visited — confirm this reuse is intended.
    def prep_vis(G):
        if G.vis is None: G.vis = u8f(G.N)
        return G.vis
    
    def prep_st(G):
        if G.st is None: G.st = elist(G.N)
        else: G.st.clear()
        return G.st
    
    def prep_order(G):
        if G.order is None: G.order = elist(G.N)
        else: G.order.clear()
        return G.order
    
    # back[v]: adjacency index used to reach v; -2 = never reached (dfs marks roots with -1).
    def prep_back(G):
        if G.back is None: G.back = i32f(G.N, -2)
        return G.back
    
    # tin[v]: discovery time of v; -1 = never reached.
    def prep_tin(G):
        if G.tin is None: G.tin = i32f(G.N, -1)
        return G.tin
    
    # len(G) is the vertex count; G[u] is a slice of u's neighbours;
    # G.range(u) is the range of u's adjacency indices.
    def __len__(G) -> int: return G.N
    def __getitem__(G, u): return G.Va[G.La[u]:G.Ra[u]]
    def range(G, u): return range(G.La[u],G.Ra[u])
    
    @overload
    def distance(G) -> list[list[int]]: ...
    @overload
    def distance(G, s: int = 0) -> list[int]: ...
    @overload
    def distance(G, s: int, g: int) -> int: ...
    # No source: all-pairs matrix via floyd_warshall; otherwise an unweighted BFS from s.
    def distance(G, s = None, g = None):
        if s == None: return G.floyd_warshall()
        else: return G.bfs(s, g)
    # Follow G.back from t to s (as left by the last search rooted at s);
    # returns the vertices t, ..., s as a u32 array.
    def recover_path(G, s, t):
        Ua, back, vertices = G.Ua, G.back, u32f(1, v := t)
        while v != s: vertices.append(v := Ua[back[v]])
        return vertices
    
    # Edge ids (Ea) along the back-pointer path, listed from t back to s.
    def recover_path_edge_ids(G, s, t):
        Ea, Ua, back, edges, v = G.Ea, G.Ua, G.back, u32f(0), t
        while v != s: edges.append(Ea[i := back[v]]), (v := Ua[i])
        return edges
    # Vertices s, ..., t of a shortest path, or None if t is unreachable.
    # G.distance(s, t) runs the search that fills G.back.
    def shortest_path(G, s: int, t: int):
        if G.distance(s, t) >= inf: return None
        vertices = G.recover_path(s, t)
        vertices.reverse()
        return vertices
    
    # Edge ids of a shortest path from s to t, or None if t is unreachable.
    def shortest_path_edge_ids(G, s: int, t: int):
        if G.distance(s, t) >= inf: return None
        edges = G.recover_path_edge_ids(s, t)
        edges.reverse()
        return edges
    
    @overload
    def bfs(G, s: Union[int,list] = 0) -> list[int]: ...
    @overload
    def bfs(G, s: Union[int,list], g: int) -> int: ...
    def bfs(G, s: int = 0, g: int = None):
        '''Breadth-first search from s (an int, a list of sources, or None for all vertices).

        Sets G.back and G.D. Returns the distance to g as soon as g is dequeued,
        inf if g is unreachable, or the full distance list when g is None.
        '''
        S, Va, back, D = G.starts(s), G.Va, i32f(N := G.N, -1), [inf]*N
        G.back, G.D = back, D
        for u in S: D[u] = 0
        que = deque(S)
        while que:
            nd = D[u := que.popleft()]+1
            if u == g: return nd-1
            for i in G.range(u):
                if nd < D[v := Va[i]]:
                    D[v], back[v] = nd, i
                    que.append(v)
        return D if g is None else inf 
    def floyd_warshall(G) -> list[list[int]]:
        '''All-pairs distances counting every edge as 1, O(N^3); also stored in G.D.'''
        G.D = D = [[inf]*G.N for _ in range(G.N)]
        for u in range(G.N): D[u][u] = 0
        for i in range(len(G.Ua)): D[G.Ua[i]][G.Va[i]] = 1
        for k, Dk in enumerate(D):
            for Di in D:
                if (Dik := Di[k]) == inf: continue
                for j in range(G.N):
                    chmin(Di, j, Dik+Dk[j])
        return D
    def find_cycle_indices(G, s: Union[int, None] = None):
        '''Find a cycle by iterative DFS and return its adjacency indices (u32 array), or None.

        vis[u]: 0 = unseen, 1 = entered but not finished, 2 = finished.
        '''
        Ea, Ua, Va, vis, back = G.Ea, G. Ua, G.Va, u8f(N := G.N), u32f(N, i32_max)
        G.vis, G.back, st = vis, back, elist(N)
        for s in G.starts(s):
            if vis[s]: continue
            st.append(s)
            while st:
                if not vis[u := st.pop()]:
                    # First visit: re-push u so it is finished after its children.
                    st.append(u)
                    # pe is the edge id used to reach u; skipping it keeps the
                    # reverse of a tree edge in an undirected graph from
                    # counting as a cycle.
                    vis[u], pe = 1, Ea[j] if (j := back[u]) != i32_max else i32_max
                    for i in G.range(u):
                        if not vis[v := Va[i]]:
                            back[v] = i
                            st.append(v)
                        elif vis[v] == 1 and pe != Ea[i]:
                            # Edge to an unfinished vertex: walk back-pointers from u to v.
                            I = u32f(1,i)
                            while v != u: I.append(i := back[u]), (u := Ua[i])
                            I.reverse()
                            return I
                else:
                    vis[u] = 2
        # check for self loops
        for i in range(len(Ua)):
            if Ua[i] == Va[i]:
                return u32f(1,i)
    
    # Vertices of a cycle (one per adjacency index), or None.
    def find_cycle(G, s: Union[int, None] = None):
        if I := G.find_cycle_indices(s): return [G.Ua[i] for i in I]
    
    # Edge ids of a cycle, or None.
    def find_cycle_edge_ids(G, s: Union[int, None] = None):
        if I := G.find_cycle_indices(s): return [G.Ea[i] for i in I]
    def find_minimal_cycle(G, s=0):
        '''BFS from s; return the first cycle closing back at s, as [u, par[u], ..., s], or None.

        NOTE(review): in an undirected graph the reverse of the first tree edge
        already leads back to s, so this returns a 2-vertex "cycle"; it appears
        intended for directed graphs — confirm.
        '''
        D, par, que, Va = u32f(N := G.N, u32_max), i32f(N, -1), deque([s]), G.Va
        D[s] = 0
        while que:
            for i in G.range(u := que.popleft()):
                if (v := Va[i]) == s:  # Found cycle back to start
                    cycle = [u]
                    while u != s: cycle.append(u := par[u])
                    return cycle
                if D[v] < u32_max: continue
                D[v], par[v] = D[u]+1, u; que.append(v)
    def dfs_topdown(G, s: Union[int,list] = None) -> list[int]:
        '''Returns lists of indices i where Ua[i] -> Va[i] are edges in order of top down discovery'''
        # Uses prep_vis(), so vertices visited by an earlier search stay visited.
        vis, st, order = G.prep_vis(), G.prep_st(), G.prep_order()
        for s in G.starts(s):
            if vis[s]: continue
            vis[s] = 1; st.append(s) 
            while st:
                for i in G.range(st.pop()):
                    if vis[v := G.Va[i]]: continue
                    vis[v] = 1; order.append(i); st.append(v)
        return order
    def dfs(G, s: Union[int,list] = None, /, 
            backtrack = False,
            max_depth = None,
            enter_fn: Callable[[int],None] = None,
            leave_fn: Callable[[int],None] = None,
            max_depth_fn: Callable[[int],None] = None,
            down_fn: Callable[[int,int,int],None] = None,
            back_fn: Callable[[int,int,int],None] = None,
            forward_fn: Callable[[int,int,int],None] = None,
            cross_fn: Callable[[int,int,int],None] = None,
            up_fn: Callable[[int,int,int],None] = None):
        '''Iterative DFS from s (int, list, or None for all vertices) with event callbacks.

        enter_fn(u): u is entered for the first time.
        max_depth_fn(u): u was entered while the stack held more than max_depth
            vertices; its edges are skipped.
        down_fn(u, v, i): tree edge to the unseen vertex v.
        back_fn(u, v, i): edge to a vertex still in progress, skipping the twin
            of the edge used to reach u.
        forward_fn(u, v, i): edge to a finished v with tin[u] < tin[v].
        cross_fn(u, v, i): any other edge to a finished v (all of them when
            forward_fn is not given).
        leave_fn(u): all edges of u are done.
        up_fn(u, parent, back[u]): u finished and returns to its parent.
        backtrack: after u finishes, reset vis[u] and its edge cursor so u can
            be entered again via another path.
        Fills G.vis, G.back and G.tin; as with prep_vis, these are not cleared
        between calls.
        '''
        I, time, vis, st, back, tin = G.La[:], -1, G.prep_vis(), G.prep_st(), G.prep_back(), G.prep_tin()
        for s in G.starts(s):
            if vis[s]: continue
            back[s], tin[s] = -1, (time := time+1); st.append(s)
            while st:
                if vis[u := st[-1]] == 0:
                    vis[u] = 1
                    if enter_fn: enter_fn(u)
                    if max_depth is not None and len(st) > max_depth:
                        I[u] = G.Ra[u]
                        if max_depth_fn: max_depth_fn(u)
                if (i := I[u]) < G.Ra[u]:
                    I[u] += 1
                    # s is reused for v's state (0/1/2); the for-loop rebinds s next iteration.
                    if (s := vis[v := G.Va[i]]) == 0:
                        back[v], tin[v] = i, (time := time+1); st.append(v)
                        if down_fn: down_fn(u,v,i)
                    elif back_fn and s == 1 and back[u] != G.twin[i]: back_fn(u,v,i)
                    elif (cross_fn or forward_fn) and s == 2:
                        if forward_fn and tin[u] < tin[v]: forward_fn(u,v,i)
                        elif cross_fn: cross_fn(u,v,i)
                else:
                    vis[u] = 2; st.pop()
                    if backtrack: vis[u], I[u] = 0, G.La[u]
                    if leave_fn: leave_fn(u)
                    if up_fn and st: up_fn(u, st[-1], back[u])
    
    def dfs_enter_leave(G, s: Union[int,list[int],None] = None) -> Sequence[tuple[DFSEvent,int]]:
        '''Iterative DFS recording (DFSEvent, vertex) pairs.

        ENTER is recorded when a vertex is first reached and LEAVE when its
        edges are exhausted. Returns a PacketList view over the packed events.
        Sets G.back, which uses -2 for unreached vertices and -1 for roots.
        '''
        N, I = G.N, G.La[:]
        st, back, plst = elist(N), i32f(N,-2), PacketList(order := elist(2*N), N-1)
        G.back, ENTER, LEAVE = back, int(DFSEvent.ENTER) << plst.shift, int(DFSEvent.LEAVE) << plst.shift
        for s in G.starts(s):
            if back[s] >= -1: continue
            back[s] = -1
            order.append(ENTER | s), st.append(s)
            while st:
                if (i := I[u := st[-1]]) < G.Ra[u]:
                    I[u] += 1
                    if back[v := G.Va[i]] >= -1: continue
                    back[v] = i; order.append(ENTER | v); st.append(v)
                else:
                    order.append(LEAVE | u); st.pop()
        return plst
    
    def starts(G, s: Union[int,list[int],None]) -> list[int]:
        '''Normalise s into start vertices: int -> [s], None -> range(N),
        list -> the same list object (not copied), other iterable -> list(s).'''
        if isinstance(s, int): return [s]
        elif s is None: return range(G.N)
        elif isinstance(s, list): return s
        else: return list(s)
    @classmethod
    def compile(cls, N: int, M: int, shift: int = -1):
        '''Parser reading M lines "u v"; shift (default -1) converts 1-based ids to 0-based.'''
        def parse(ts: TokenStream):
            U, V = u32f(M), u32f(M)
            for i in range(M):
                u, v = ts._line()
                U[i], V[i] = int(u)+shift, int(v)+shift
            return cls(N, U, V)
        return parse
    
# Typed stub so tools see elist's signature; it is rebound below.
# elist(n) returns a new empty list. On PyPy the estimate is forwarded to
# __pypy__.newlist_hint, presumably as a capacity hint.
def elist(est_len: int) -> list: ...
try:
    from __pypy__ import newlist_hint
except ImportError:
    # Not running on PyPy. Only a missing module is expected here; other
    # errors should surface instead of being swallowed by a bare except.
    def newlist_hint(hint):
        return []
elist = newlist_hint
    
from array import array
def _filled(code: str, N: int, elm):
    """Return an ``array`` of typecode ``code`` holding ``N`` copies of ``elm``."""
    return array(code, (elm,)) * N

# Fixed-length typed arrays filled with ``elm`` (default 0).
def i8f(N: int, elm: int = 0):      return _filled('b', N, elm)  # signed char
def u8f(N: int, elm: int = 0):      return _filled('B', N, elm)  # unsigned char
def i16f(N: int, elm: int = 0):     return _filled('h', N, elm)  # signed short
def u16f(N: int, elm: int = 0):     return _filled('H', N, elm)  # unsigned short
def i32f(N: int, elm: int = 0):     return _filled('i', N, elm)  # signed int
def u32f(N: int, elm: int = 0):     return _filled('I', N, elm)  # unsigned int
def i64f(N: int, elm: int = 0):     return _filled('q', N, elm)  # signed long long
# def u64f(N: int, elm: int = 0):     return array('Q', (elm,))*N  # unsigned long long
def f32f(N: int, elm: float = 0.0): return _filled('f', N, elm)  # float
def f64f(N: int, elm: float = 0.0): return _filled('d', N, elm)  # double
def _typed(code: str, init):
    """Return an empty ``array`` of typecode ``code``, or one initialised from ``init``."""
    if init is None:
        return array(code)
    return array(code, init)

# Typed arrays built from an optional initial iterable.
def i8a(init = None):  return _typed('b', init)  # signed char
def u8a(init = None):  return _typed('B', init)  # unsigned char
def i16a(init = None): return _typed('h', init)  # signed short
def u16a(init = None): return _typed('H', init)  # unsigned short
def i32a(init = None): return _typed('i', init)  # signed int
def u32a(init = None): return _typed('I', init)  # unsigned int
def i64a(init = None): return _typed('q', init)  # signed long long
# def u64a(init = None): return array('Q') if init is None else array('Q', init)  # unsigned long long
def f32a(init = None): return _typed('f', init)  # float
def f64a(init = None): return _typed('d', init)  # double
# Largest values of the signed (i*) and unsigned (u*) fixed-width integer types.
i8_max = 2**7 - 1
u8_max = 2**8 - 1
i16_max = 2**15 - 1
u16_max = 2**16 - 1
i32_max = 2**31 - 1
u32_max = 2**32 - 1
i64_max = 2**63 - 1
u64_max = 2**64 - 1
class PacketList(Sequence[tuple[int,int]]):
    """Read-only view of a list of packed ints as ``(high, low)`` pairs.

    Each stored value is ``high << shift | low``, where ``shift`` is the bit
    length of ``max1`` (the largest possible ``low``).
    """
    def __init__(lst, A: list[int], max1: int):
        lst.A = A
        lst.mask = (1 << (shift := (max1).bit_length())) - 1
        lst.shift = shift
    def __len__(lst): return lst.A.__len__()
    def __contains__(lst, x: tuple[int,int]): return lst.A.__contains__(x[0] << lst.shift | x[1])
    def __getitem__(lst, key) -> tuple[int,int]:
        """Decode entry ``key``; a slice returns a list of decoded pairs."""
        shift, mask = lst.shift, lst.mask
        if isinstance(key, slice):
            # Slicing the backing list yields a list, which cannot be shifted directly.
            return [(x >> shift, x & mask) for x in lst.A[key]]
        x = lst.A[key]
        return x >> shift, x & mask
class GraphWeightedBase(GraphBase):
    def __init__(self, N: int, M: int, U: list[int], V: list[int], W: list[int], 
                 deg: list[int], La: list[int], Ra: list[int],
                 Ua: list[int], Va: list[int], Wa: list[int], Ea: list[int], twin: list[int] = None):
        super().__init__(N, M, U, V, deg, La, Ra, Ua, Va, Ea, twin)
        self.W = W
        self.Wa = Wa
        '''Wa[i] lists weights to edges from u for La[u] <= i < Ra[u].'''
        
    def __getitem__(G, u):
        l,r = G.La[u],G.Ra[u]
        return zip(G.Va[l:r], G.Wa[l:r])
    
    @overload
    def distance(G) -> list[list[int]]: ...
    @overload
    def distance(G, s: int = 0) -> list[int]: ...
    @overload
    def distance(G, s: int, g: int) -> int: ...
    def distance(G, s = None, g = None):
        if s == None: return G.floyd_warshall()
        else: return G.dijkstra(s, g)
    def dijkstra(G, s: int, t: int = None):
        N, S, Va, Wa = G.N, G.starts(s), G.Va, G.Wa
        G.back, G.D  = back, D = i32f(N, -1), [inf]*N
        for s in S: D[s] = 0
        que = PriorityQueue(N, S)
        while que:
            u, d = que.pop()
            if d > D[u]: continue
            if u == t: return d
            for i in G.range(u): 
                if chmin(D, v := Va[i], nd := d + Wa[i]):
                    back[v] = i
                    que.push(v, nd)
        return D if t is None else inf 
    def kruskal(G):
        U, V, W, dsu, MST, need = G.U, G.V, G.W, DSU(N := G.N), [0]*(N-1), N-1
        for e in argsort(W):
            u, v = dsu.merge(U[e],V[e],True)
            if u != v:
                MST[need := need-1] = e
                if not need: break
        return None if need else MST
    
    def kruskal_heap(G):
        N, M, U, V, W = G.N, G.M, G.U, G.V, G.W 
        que = PriorityQueue(M, list(range(M)), W)
        dsu = DSU(N)
        MST = [0]*(N-1)
        need = N-1
        while que and need:
            e, _ = que.pop()
            u, v = dsu.merge(U[e],V[e],True)
            if u != v:
                MST[need := need-1] = e
        return None if need else MST
   
    def bellman_ford(G, s: int = 0) -> list[int]:
        Ua, Va, Wa, D = G.Ua, G.Va, G.Wa, [inf]*(N := G.N)
        D[s] = 0
        for _ in range(N-1):
            for i, u in enumerate(Ua):
                if D[u] < inf: chmin(D, Va[i], D[u] + Wa[i])
        return D
    
    def bellman_ford_neg_cyc_check(G, s: int = 0) -> tuple[bool, list[int]]:
        M, U, V, W, D = G.M, G.U, G.V, G.W, G.bellman_ford(s)
        neg_cycle = any(D[U[i]]+W[i]<D[V[i]] for i in range(M) if D[U[i]] < inf)
        return neg_cycle, D
    
    def floyd_warshall(G) -> list[list[int]]:
        N, Ua, Va, Wa = G.N, G.Ua, G.Va, G.Wa
        D = [[inf]*N for _ in range(N)]
        for u in range(N): D[u][u] = 0
        for i in range(len(Ua)): chmin(D[Ua[i]], Va[i], Wa[i])
        for k, Dk in enumerate(D):
            for Di in D:
                if Di[k] >= inf: continue
                for j in range(N):
                    if Dk[j] >= inf: continue
                    chmin(Di, j, Di[k]+Dk[j])
        return D
        
    def floyd_warshall_neg_cyc_check(G):
        D = G.floyd_warshall()
        return any(D[i][i] < 0 for i in range(G.N)), D
    
    @classmethod
    def compile(cls, N: int, M: int, shift: int = -1):
        def parse(ts: TokenStream):
            U, V, W = u32f(M), u32f(M), [0]*M
            for i in range(M):
                u, v, w = ts._line()
                U[i], V[i], W[i] = int(u)+shift, int(v)+shift, int(w)
            return cls(N, U, V, W)
        return parse
class DSU:
    """Disjoint-set union with union by size and path halving.

    ``par[x] < 0`` marks ``x`` as a root whose set has ``-par[x]`` members;
    otherwise ``par[x]`` is the parent of ``x``.
    """
    def __init__(self, N):
        self.N = N
        self.par = [-1] * N

    def merge(self, u, v, src = False):
        """Join the sets of u and v and return the surviving leader.

        With ``src`` true, returns ``(leader, absorbed_root)`` instead; both are
        equal when u and v were already joined.
        """
        assert 0 <= u < self.N
        assert 0 <= v < self.N
        root, other = self.leader(u), self.leader(v)
        if root != other:
            par = self.par
            if par[root] > par[other]:  # keep the root of the larger set
                root, other = other, root
            par[root] += par[other]
            par[other] = root
        if src:
            return root, other
        return root

    def same(self, u: int, v: int):
        """Return whether u and v are in the same set."""
        assert 0 <= u < self.N
        assert 0 <= v < self.N
        return self.leader(u) == self.leader(v)

    def leader(self, i) -> int:
        """Return the root of i's set, halving the path along the way."""
        assert 0 <= i < self.N
        parent = self.par
        while True:
            up = parent[i]
            if up < 0:
                return i
            grand = parent[up]
            if grand < 0:
                return up
            parent[i] = grand
            i = grand

    def size(self, i) -> int:
        """Return the number of elements in i's set."""
        assert 0 <= i < self.N
        return -self.par[self.leader(i)]

    def groups(self) -> list[list[int]]:
        """Return the sets as lists, ordered by their leader's index."""
        buckets = [[] for _ in range(self.N)]
        for i in range(self.N):
            buckets[self.leader(i)].append(i)
        return [b for b in buckets if b]
from collections import UserList
from heapq import heapify, heappop, heappush, heappushpop, heapreplace
from typing import Generic
class HeapProtocol(Generic[_T]):
    """Interface of the heap-backed containers in this file (mirrors ``heapq``)."""
    def pop(self) -> _T:
        """Remove and return the smallest item (cf. ``heapq.heappop``)."""
    def push(self, item: _T):
        """Insert ``item`` (cf. ``heapq.heappush``)."""
    def pushpop(self, item: _T) -> _T:
        """Push ``item``, then pop and return the smallest (cf. ``heapq.heappushpop``)."""
    def replace(self, item: _T) -> _T:
        """Pop and return the smallest, then push ``item`` (cf. ``heapq.heapreplace``)."""
class PriorityQueue(HeapProtocol[int], UserList[int]):
    """Min-heap of integer ids keyed by integer priorities.

    Each entry is stored as one int ``priority << shift | id`` with
    ``shift = N.bit_length()`` (ids must be below ``2**shift``), so heapq
    compares plain ints. Ties on priority go to the smaller id.
    """

    def __init__(self, N: int, ids: list[int] = None, priorities: list[int] = None, /):
        shift = N.bit_length()
        self.shift = shift
        self.mask = (1 << shift) - 1
        if ids is None:
            self.data = elist(N)
        elif priorities is not None:
            packed = [self.encode(ids[k], priorities[k]) for k in range(len(ids))]
            heapify(packed)
            self.data = packed
        else:
            # Priority 0 packs to the id itself, so the caller's list is
            # heapified and adopted in place (it is mutated and consumed).
            heapify(ids)
            self.data = ids
    def encode(self, id, priority):
        """Pack ``(id, priority)`` into one int."""
        return (priority << self.shift) | id

    def decode(self, encoded):
        """Unpack an int into ``(id, priority)``."""
        return encoded & self.mask, encoded >> self.shift

    def pop(self):
        """Remove and return the ``(id, priority)`` pair with the smallest priority."""
        top = heappop(self.data)
        return self.decode(top)

    def push(self, id: int, priority: int):
        """Insert ``id`` with ``priority``."""
        packed = self.encode(id, priority)
        heappush(self.data, packed)
    def pushpop(self, id: int, priority: int):
        """Push, then pop and return the smallest ``(id, priority)``."""
        packed = self.encode(id, priority)
        return self.decode(heappushpop(self.data, packed))

    def replace(self, id: int, priority: int):
        """Pop the smallest ``(id, priority)``, push the new pair, return the popped one."""
        packed = self.encode(id, priority)
        return self.decode(heapreplace(self.data, packed))
    
class GraphWeighted(GraphWeightedBase):
    """Undirected weighted graph in flat adjacency form.

    Edge ``e = (U[e], V[e], W[e])`` appears in the adjacency of both endpoints
    (a self loop only once). ``twin[i]`` is the index of the opposite direction
    of entry ``i`` (itself for self loops).
    """
    def __init__(G, N: int, U: list[int], V: list[int], W: list[int]):
        M = len(U)
        # Count adjacency entries per vertex: two per ordinary edge, one per self loop.
        deg, Ma = u32f(N), 0
        for e in range(M):
            u, v = U[e], V[e]
            if u == v:
                deg[u] += 1
                Ma += 1
            else:
                deg[u] += 1
                deg[v] += 1
                Ma += 2
        twin, Ea, Ua, Va, Wa = u32f(Ma), u32f(Ma), u32f(Ma), u32f(Ma), [0]*Ma

        # La[u] = start of u's slice (prefix sums of deg); Ra is the fill cursor.
        La, acc = u32f(N), 0
        for u in range(N):
            La[u] = acc
            acc += deg[u]
        Ra = La[:]
        for e in range(M):
            u, v, w = U[e], V[e], W[e]
            i, j = Ra[u], Ra[v]  # read both cursors before advancing either
            Ra[u] = i + 1
            Ua[i], Va[i], Wa[i], Ea[i], twin[i] = u, v, w, e, j
            if i != j:  # a self loop is stored only once
                Ra[v] = j + 1
                Ua[j], Va[j], Wa[j], Ea[j], twin[j] = v, u, w, e, i
        super().__init__(N, M, U, V, W, deg, La, Ra, Ua, Va, Wa, Ea, twin)
from typing import Optional
from typing import Callable, Literal, TypeVar, Union, overload
class TreeBase(GraphBase):
    @overload
    def distance(T) -> list[list[int]]: ...
    @overload
    def distance(T, s: int = 0) -> list[int]: ...
    @overload
    def distance(T, s: int, g: int) -> int: ...
    def distance(T, s = None, g = None):
        if s == None:
            return [T.dfs_distance(u) for u in range(T.N)]
        else:
            return T.dfs_distance(s, g)
    @overload
    def diameter(T) -> int: ...
    @overload
    def diameter(T, endpoints: Literal[True]) -> tuple[int,int,int]: ...
    def diameter(T, endpoints = False):
        mask = (1 << (shift := T.N.bit_length())) - 1
        s = max(d << shift | v for v,d in enumerate(T.distance(0))) & mask
        dg = max(d << shift | v for v,d in enumerate(T.distance(s))) 
        diam, g = dg >> shift, dg & mask
        return (diam, s, g) if endpoints else diam
    
    def dfs_distance(T, s: int, g: Union[int,None] = None):
        st, Va = elist(N := T.N), T.Va
        T.D, T.back = D, back = [inf]*N, i32f(N, -1)
        D[s] = 0
        st.append(s)
        while st:
            nd = D[u := st.pop()]+1
            if u == g: return nd-1
            for i in T.range(u):
                if nd < D[v := Va[i]]:
                    D[v], back[v] = nd, i
                    st.append(v)
        return D if g is None else inf
    def rerooting_dp(T, e: _T, 
                     merge: Callable[[_T,_T],_T], 
                     edge_op: Callable[[_T,int,int,int],_T] = lambda s,i,p,u:s,
                     s: int = 0):
        La, Ua, Va = T.La, T.Ua, T.Va
        order, dp, suf, I = T.dfs_topdown(s), [e]*T.N, [e]*len(Ua), T.Ra[:]
        # up
        for i in order[::-1]:
            u,v = Ua[i], Va[i]
            # subtree v finished up pass, store value to accumulate for u
            dp[v] = new = edge_op(dp[v], i, u, v)
            dp[u] = merge(dp[u], new)
            # suffix accumulation
            if (c:=I[u]-1) > La[u]: suf[c-1] = merge(suf[c], new)
            I[u] = c
        # down
        dp[s] = e # at this point dp stores values to be merged in parent
        for i in order:
            u,v = Ua[i], Va[i]
            dp[u] = merge(pre := dp[u], dp[v])
            dp[v] = edge_op(merge(suf[I[u]], pre), i, v, u)
            I[u] += 1
        return dp
    
    def euler_tour(T, s = 0):
        N, Va = len(T), T.Va
        tin, tout, par, back = [-1]*N,[-1]*N,[-1]*N,[0]*N
        order, delta = elist(2*N), elist(2*N)
        
        st = elist(N); st.append(s)
        while st:
            p = par[u := st.pop()]
            if tin[u] == -1:
                tin[u] = len(order)
                for i in T.range(u):
                    if (v := Va[i]) != p:
                        par[v], back[v] = u, i
                        st.append(u); st.append(v)
                delta.append(1)
            else:
                delta.append(-1)
            
            order.append(u)
            tout[u] = len(order)
        delta[0] = delta[-1] = 0
        T.tin, T.tout, T.par, T.back = tin, tout, par, back
        T.order, T.delta = order, delta
    def hld_precomp(T, r = 0):
        N, time, Va = T.N, 0, T.Va
        tin, tout, size = [0]*N, [0]*N, [1]*N+[0]
        par, heavy, head = [-1]*N, [-1]*N, [r]*N
        depth, order, vis = [0]*N, [0]*N, [0]*N
        st = elist(N)
        st.append(r)
        while st:
            if (s := vis[v := st.pop()]) == 0: # dfs down
                p, vis[v] = par[v], 1; st.append(v)
                for i in T.range(v):
                    if (c := Va[i]) != p:
                        depth[c], par[c] = depth[v]+1, v; st.append(c)
            elif s == 1: # dfs up
                p, l = par[v], -1
                for i in T.range(v):
                    if (c := Va[i]) != p:
                        size[v] += size[c]
                        if size[c] > size[l]:
                            l = c
                heavy[v] = l
                if p == -1:
                    vis[v] = 2
                    st.append(v)
            elif s == 2: # decompose down
                p, h, l = par[v], head[v], heavy[v]
                tin[v], order[time], vis[v] = time, v, 3
                time += 1
                st.append(v)
                
                for i in T.range(v):
                    if (c := Va[i]) != p and c != l:
                        head[c], vis[c] = c, 2
                        st.append(c)
                if l != -1:
                    head[l], vis[l] = h, 2
                    st.append(l)
            elif s == 3: # decompose up
                tout[v] = time
        T.size, T.depth = size, depth
        T.order, T.tin, T.tout = order, tin, tout
        T.par, T.heavy, T.head = par, heavy, head
    @classmethod
    def compile(cls, N: int, shift: int = -1):
        return GraphBase.compile.__func__(cls, N, N-1, shift)
    
class TreeWeightedBase(TreeBase, GraphWeightedBase):
    def dfs_distance(T, s: int, g: Optional[int] = None):
        st, Wa, Va = elist(N := T.N), T.Wa, T.Va
        T.D, T.back = D, back = [inf]*N, i32f(N, -1)
        D[s] = 0; st.append(s)
        while st:
            d = D[u := st.pop()]
            if u == g: return d
            for i in T.range(u):
                if (nd := d+Wa[i]) < D[v := Va[i]]:
                    D[v], back[v] = nd, i; st.append(v)
        return D if g is None else inf 
    
    def euler_tour(T, s = 0):
        N, Va, Wa = len(T), T.Va, T.Wa
        tin, tout, par = [-1]*N,[-1]*N,[-1]*N
        order, delta, Wdelta = elist(2*N), elist(2*N), elist(2*N)
        st, Wst = elist(N), elist(N)
        st.append(s); Wst.append(0)
        while st:
            p, wd = par[u := st.pop()], Wst.pop()
            if tin[u] == -1:
                tin[u] = len(order)
                for i in T.range(u):
                    if (v := Va[i]) != p:
                        w, par[v] = Wa[i], u
                        st.append(u); st.append(v); Wst.append(-w); Wst.append(w)
                delta.append(1)
            else:
                delta.append(-1)
            Wdelta.append(wd); order.append(u)
            tout[u] = len(order)
        delta[0] = delta[-1] = 0
        T.tin, T.tout, T.par = tin, tout, par
        T.order, T.delta, T.Wdelta = order, delta, Wdelta
    def hld_precomp(T, r = 0):
        N, time, Va, Wa = T.N, 0, T.Va, T.Wa
        tin, tout, size = [0]*N, [0]*N, [1]*N+[0]
        par, heavy, head = [-1]*N, [-1]*N, [r]*N
        depth, order, vis = [0]*N, [0]*N, [0]*N
        Wpar = [0]*N
        st = elist(N)
        st.append(r)
        while st:
            if (s := vis[v := st.pop()]) == 0: # dfs down
                p, vis[v] = par[v], 1
                st.append(v)
                for i in T.range(v):
                    if (c := Va[i]) != p:
                        depth[c], par[c], Wpar[c] = depth[v]+1, v, Wa[i]
                        st.append(c)
            elif s == 1: # dfs up
                p, l = par[v], -1
                for i in T.range(v):
                    if (c := Va[i]) != p:
                        size[v] += size[c]
                        if size[c] > size[l]:
                            l = c
                heavy[v] = l
                if p == -1:
                    vis[v] = 2
                    st.append(v)
            elif s == 2: # decompose down
                p, h, l = par[v], head[v], heavy[v]
                tin[v], order[time], vis[v] = time, v, 3
                time += 1
                st.append(v)
                
                for i in T.range(v):
                    if (c := Va[i]) != p and c != l:
                        head[c], vis[c] = c, 2
                        st.append(c)
                if l != -1:
                    head[l], vis[l] = h, 2
                    st.append(l)
            elif s == 3: # decompose up
                tout[v] = time
        T.size, T.depth = size, depth
        T.order, T.tin, T.tout = order, tin, tout
        T.par, T.heavy, T.head = par, heavy, head
        T.Wpar = Wpar
    @classmethod
    def compile(cls, N: int, shift: int = -1):
        return GraphWeightedBase.compile.__func__(cls, N, N-1, shift)
    
class TreeWeighted(TreeWeightedBase, GraphWeighted):
    '''Weighted tree: GraphWeighted storage combined with TreeWeightedBase algorithms.'''
class AuxTreeBase(TreeWeightedBase):
    def __init__(T, lca: LCATable):
        T.lca = lca
        T.Vset = elist(T.N)
        T.post = elist(T.N-1)
        T.Ra = T.La[:]
    def add(T, u, v):
        w = T.lca.distance(u,v)
        i, j = T.Ra[u], T.Ra[v]
        T.Ua[i], T.Va[i], T.Wa[i], T.twin[i], T.Ra[u] = u, v, w, j, i+1
        if i == j: return j
        T.Ua[j], T.Va[j], T.Wa[j], T.twin[j], T.Ra[v] = v, u, w, i, j+1
        return j
    def tree(T, U: list[int], sort=True):
        if sort: U = sorted(U, key = T.tin.__getitem__)
        st = T.prep_st()
        lca, tin, V, post = T.lca, T.tin, T.Vset, T.post
        # reset
        while V:
            T.Ra[u] = T.La[u := V.pop()]
            if T.vis: T.vis[u] = 0
        post.clear()
        st.append(U[0])
        for j in range(len(U)-1):
            u, v = U[j], U[j+1]
            a, _ = lca.query(u, v)
            if a != u:
                l = st.pop()
                while st and tin[t := st[-1]] > tin[a]:
                    V.append(l); post.append(T.add(l, l := st.pop()))
                if not st or t != a: st.append(a)
                V.append(l); post.append(T.add(l, a))
            st.append(v)
        l = st.pop()
        while st: V.append(l); post.append(T.add(l, l := st.pop()))
        V.append(l)
        return V, post
    def trees(T, C: list[int]):
        lca, N = T.lca, T.N
        T.Ra, cnt, order = T.La[:], [0]*N, argsort(T.tin)
        for c in C: cnt[c] += 1
        L = [0]*N
        for i in range(N-1): L[i+1] = L[i]+cnt[i]
        R, G = L[:], [0]*N
        
        for i in order: c = C[i]; G[R[c]] = i; R[c] += 1
        st, V, post = elist(N), elist(N), elist(N)
        La, Ra, tin = T.La, T.Ra, T.tin
        for c in range(N):
            l, r = L[c], R[c]
            if l == r: continue
            st.append(G[l])
            for j in range(l,r-1):
                u, v = G[j], G[j+1]
                a, _ = lca.query(u, v)
                if a != u:
                    l = st.pop()
                    while st and tin[t := st[-1]] > tin[a]:
                        V.append(l); post.append(T.add(l, l := st.pop()))
                    if not st or t != a: st.append(a)
                    V.append(l); post.append(T.add(l, a))
                st.append(v)
            l = st.pop()
            while st: V.append(l); post.append(T.add(l, l := st.pop()))
            V.append(l)
            yield c, V, post
            while V:
                Ra[u] = La[u := V.pop()]
                if T.vis: T.vis[u] = 0
            post.clear()
    def rerooting_dp(T, C: list[int], e: _T, 
                     merge: Callable[[_T,_T],_T], 
                     edge_op: Callable[[_T,int,int,int,int],_T] = lambda s,i,p,u,c:s):
        ans, dp, suf, I = [e]*T.N, [e]*T.N, [e]*len(T.Ua), T.La[:]
        for c, V, post in T.trees(C):
            r = V[-1]
            for v in V: I[v] = T.Ra[v]
            # up
            for i in post:
                u,v = T.Ua[i], T.Va[i]
                # subtree v finished up pass, store value to accumulate for u
                dp[v] = new = edge_op(dp[v], i, u, v, c)
                dp[u] = merge(dp[u], new)
                # suffix accumulation
                if (j:=I[u]-1) > T.La[u]: suf[j-1] = merge(suf[j], new)
                I[u] = j
            # down
            dp[r] = e # at this point dp stores values to be merged in parent
            for i in reversed(post):
                u,v = T.Ua[i], T.Va[i]
                dp[u] = merge(pre := dp[u], dp[v])
                dp[v] = edge_op(merge(suf[I[u]], pre), i, v, u, c)
                I[u] += 1
            
            # store ans and reset
            for v in V:
                if C[v] == c: ans[v] = dp[v]
                dp[v] = e
            for i in post:
                suf[i] = e
        return ans
class AuxTreeWeighted(AuxTreeBase, TreeWeighted):
    '''Weighted tree that answers auxiliary-tree queries.

    The LCA table is built from the freshly constructed tree before
    AuxTreeBase.__init__ repurposes the adjacency arrays.
    '''
    def __init__(T, N, U, V, W, root=0):
        TreeWeighted.__init__(T, N, U, V, W)
        table = LCATableWeighted(T, root)
        AuxTreeBase.__init__(T, table)
    @classmethod
    def compile(cls, N: int, shift: int = -1, root=0):
        '''Return a parser that reads N-1 edge lines "u v w" into an instance.

        Endpoints are offset by ``shift`` (the default -1 maps 1-based ids to
        0-based). The LCA table is rooted at ``root``.
        '''
        edges = N - 1
        def parse(ts: TokenStream):
            U, V, W = u32f(edges), u32f(edges), [0]*edges
            for e in range(edges):
                a, b, c = ts._line()
                U[e], V[e], W[e] = int(a)+shift, int(b)+shift, int(c)
            return cls(N, U, V, W, root)
        return parse
from typing import Iterable, Type, Union, overload
@overload
def read() -> Iterable[int]: ...
@overload
def read(spec: int) -> list[int]: ...
@overload
def read(spec: Union[Type[_T],_T], char=False) -> _T: ...
def read(spec: Union[Type[_T],_T] = None, char=False):
    '''Read the next value(s) from the default input stream.

    With no spec (and char=False) the next line is returned as a lazy map of
    ints. Otherwise Parser.compile(spec) builds a parser, which is applied to
    CharStream.default when char=True and to TokenStream.default otherwise.
    '''
    if spec is None and not char:
        return map(int, TokenStream.default.line())
    parser: _T = Parser.compile(spec)
    stream = CharStream.default if char else TokenStream.default
    return parser(stream)
def write(*args, **kwargs):
    '''Print the values to a stream, or to stdout_fast by default.

    Mirrors the builtin ``print``: ``sep`` (default " ") goes between values,
    ``end`` (default "\\n") after the last one, ``file`` selects the stream and
    ``flush=True`` flushes it afterwards. As with ``print``, passing ``None``
    for ``sep``, ``end`` or ``file`` selects the default. Unknown keyword
    arguments are ignored.
    '''
    sep = kwargs.pop("sep", None)
    end = kwargs.pop("end", None)
    file = kwargs.pop("file", None)
    if sep is None: sep = " "
    if end is None: end = "\n"
    # looked up lazily so an explicit file never touches IOWrapper
    if file is None: file = IOWrapper.stdout
    for k, x in enumerate(args):
        if k: file.write(sep)
        file.write(str(x))
    file.write(end)
    if kwargs.pop("flush", False):
        file.flush()
if __name__ == "__main__":
    # Solve only when executed as a script, not when imported.
    main()
            
            
            
        