結果
問題 | No.901 K-ary εxtrεεmε |
ユーザー | StanMarsh |
提出日時 | 2024-02-27 00:27:36 |
言語 | PyPy3 (7.3.15) |
結果 | AC |
実行時間 | 943 ms / 3,000 ms |
コード長 | 14,449 bytes |
コンパイル時間 | 248 ms |
コンパイル使用メモリ | 81,828 KB |
実行使用メモリ | 231,924 KB |
最終ジャッジ日時 | 2024-02-27 00:28:02 |
合計ジャッジ時間 | 25,404 ms |
ジャッジサーバーID (参考情報) | judge12 / judge14 |
(要ログイン)
テストケース
テストケース | 結果 | 実行時間 | 実行使用メモリ |
---|---|---|---|
testcase_00 | AC | 673 ms | 231,924 KB |
testcase_01 | AC | 146 ms | 91,464 KB |
testcase_02 | AC | 235 ms | 92,744 KB |
testcase_03 | AC | 203 ms | 92,744 KB |
testcase_04 | AC | 200 ms | 92,872 KB |
testcase_05 | AC | 200 ms | 92,616 KB |
testcase_06 | AC | 204 ms | 92,616 KB |
testcase_07 | AC | 936 ms | 210,932 KB |
testcase_08 | AC | 885 ms | 210,044 KB |
testcase_09 | AC | 894 ms | 210,012 KB |
testcase_10 | AC | 867 ms | 209,624 KB |
testcase_11 | AC | 903 ms | 209,632 KB |
testcase_12 | AC | 905 ms | 210,776 KB |
testcase_13 | AC | 907 ms | 211,412 KB |
testcase_14 | AC | 921 ms | 210,392 KB |
testcase_15 | AC | 923 ms | 210,896 KB |
testcase_16 | AC | 881 ms | 209,468 KB |
testcase_17 | AC | 902 ms | 210,872 KB |
testcase_18 | AC | 897 ms | 210,476 KB |
testcase_19 | AC | 898 ms | 211,364 KB |
testcase_20 | AC | 931 ms | 210,476 KB |
testcase_21 | AC | 918 ms | 211,760 KB |
testcase_22 | AC | 927 ms | 220,884 KB |
testcase_23 | AC | 943 ms | 220,868 KB |
testcase_24 | AC | 909 ms | 221,380 KB |
testcase_25 | AC | 915 ms | 218,676 KB |
testcase_26 | AC | 916 ms | 217,540 KB |
testcase_27 | AC | 806 ms | 210,964 KB |
testcase_28 | AC | 821 ms | 210,448 KB |
testcase_29 | AC | 799 ms | 211,828 KB |
ソースコード
from random import getrandbits, randrange
from string import ascii_lowercase, ascii_uppercase
import sys
import math
from math import ceil, floor, sqrt, pi, factorial, gcd, log, log10, log2, inf, cos, sin
from copy import deepcopy, copy
from collections import Counter, deque, defaultdict
from heapq import heapify, heappop, heappush
from itertools import (
    accumulate,
    chain,
    product,
    combinations,
    combinations_with_replacement,
    permutations,
)
from bisect import bisect, bisect_left, bisect_right
from functools import lru_cache, reduce
from decimal import Decimal, getcontext
from typing import List, Tuple, Optional


class Inf:
    """Comparison sentinel that orders above every non-``Inf`` value.

    - ``<`` is always False and ``>=`` is always True.
    - ``<=`` is True only against another ``Inf``.
    - ``>`` is True only against a non-``Inf``.
    - ``==`` requires another ``Inf`` with the same ``value``.

    Arithmetic with a non-``Inf`` operand returns the sentinel itself.
    Arithmetic with another ``Inf`` returns a fresh ``Inf(self.value)``.
    """

    def __init__(self, value):
        self.value = value

    def __lt__(self, other):
        return False

    def __le__(self, other):
        if isinstance(other, Inf):
            return True
        return False

    def __gt__(self, other):
        if isinstance(other, Inf):
            return False
        return True

    def __ge__(self, other):
        return True

    def __eq__(self, other):
        return isinstance(other, Inf) and self.value == other.value

    def __repr__(self):
        return f"{self.value}"

    def __add__(self, other):
        return Inf(self.value) if isinstance(other, Inf) else self

    def __sub__(self, other):
        return Inf(self.value) if isinstance(other, Inf) else self

    def __mul__(self, other):
        return Inf(self.value) if isinstance(other, Inf) else self


def ceil_div(a, b):
    """Return ``ceil(a / b)`` for integers, correct for every sign of a and b.

    Raises ZeroDivisionError when ``b == 0``.
    """
    # ``(a + b - 1) // b`` is only right for b > 0 (e.g. it gives -2 for
    # 7 / -2); negated floor division of the negated numerator is exact
    # for all sign combinations.
    return -(-a // b)


def isqrt(num):
    """Return ``floor(sqrt(num))`` exactly, for arbitrarily large ``num``.

    Raises:
        ValueError: if ``num`` is negative.
    """
    if num < 0:
        raise ValueError("math domain error")
    # The float ``sqrt`` + correction-loop approach overflows for ints beyond
    # float range and needs ~res * 2**-52 correction steps for large ints.
    # floor(sqrt(x)) == isqrt(floor(x)) for x >= 0, so non-int inputs
    # (floats, Decimals) are still accepted via ``int``.
    return math.isqrt(int(num))


def int1(s):
    """Parse ``s`` as an int and shift it from 1-based to 0-based."""
    return int(s) - 1


from types import GeneratorType


def bootstrap(f, stack=[]):
    """Decorator that runs a generator-based recursive function iteratively.

    The decorated function ``yield``s recursive calls and ``yield``s its
    result. The generators are driven from an explicit stack, so deep
    recursion does not hit the interpreter's recursion limit. The mutable
    default ``stack`` is intentional: it is shared across all calls.
    """

    def wrapped(*args, **kwargs):
        if stack:
            # Already inside a driven call: hand back the generator.
            return f(*args, **kwargs)
        else:
            to = f(*args, **kwargs)
            while True:
                if type(to) is GeneratorType:
                    stack.append(to)
                    to = next(to)
                else:
                    stack.pop()
                    if not stack:
                        break
                    to = stack[-1].send(to)
            return to

    return wrapped


import sys
import os
from io import BytesIO, IOBase

BUFSIZE = 8192


class FastIO(IOBase):
    """Byte buffer layered over a raw file descriptor for bulk reads/writes."""

    newlines = 0

    def __init__(self, file):
        self._fd = file.fileno()
        self.buffer = BytesIO()
        self.writable = "x" in file.mode or "r" not in file.mode
        self.write = self.buffer.write if self.writable else None

    def read(self):
        """Read everything until EOF and return the unread buffered bytes."""
        while True:
            b = os.read(self._fd, max(os.fstat(self._fd).st_size, BUFSIZE))
            if not b:
                break
            ptr = self.buffer.tell()
            self.buffer.seek(0, 2), self.buffer.write(b), self.buffer.seek(ptr)
        self.newlines = 0
        return self.buffer.read()

    def readline(self):
        """Return one line, reading chunks until a newline (or EOF) is buffered."""
        while self.newlines == 0:
            b = os.read(self._fd, max(os.fstat(self._fd).st_size, BUFSIZE))
            # At EOF (empty read) count one pseudo-line so the loop ends.
            self.newlines = b.count(b"\n") + (not b)
            ptr = self.buffer.tell()
            self.buffer.seek(0, 2), self.buffer.write(b), self.buffer.seek(ptr)
        self.newlines -= 1
        return self.buffer.readline()

    def flush(self):
        """Write all buffered bytes to the descriptor and clear the buffer."""
        if self.writable:
            os.write(self._fd, self.buffer.getvalue())
            self.buffer.truncate(0), self.buffer.seek(0)


class IOWrapper(IOBase):
    """Text (ASCII) facade over :class:`FastIO`."""

    def __init__(self, file):
        self.buffer = FastIO(file)
        self.flush = self.buffer.flush
        self.writable = self.buffer.writable
        self.write = lambda s: self.buffer.write(s.encode("ascii"))
        self.read = lambda: self.buffer.read().decode("ascii")
        self.readline = lambda: self.buffer.readline().decode("ascii")


# Swap in the fd-based fast streams. If a stream has no usable file
# descriptor (e.g. a harness substitutes an in-memory stream), keep the
# interpreter's own streams; ``input``/``print`` below work with either.
try:
    sys.stdin, sys.stdout = IOWrapper(sys.stdin), IOWrapper(sys.stdout)
except (OSError, ValueError, AttributeError):
    pass

input = lambda: sys.stdin.readline().rstrip("\r\n")
print = lambda *args, end="\n", sep=" ": sys.stdout.write(
    sep.join(map(str, args)) + end
)


def II():
    """Read one line as a single int."""
    return int(input())


def MII(base=0):
    """Read one line of ints, each reduced by ``base``, as a lazy map."""
    return map(lambda s: int(s) - base, input().split())


def LII(base=0):
    """Read one line of ints, each reduced by ``base``, as a list."""
    return list(MII(base))


def NA():
    """Read ``n`` on one line, then a list of ints on the next; return both."""
    n = II()
    a = LII()
    return n, a


def read_graph(n, m, base=0, directed=False, return_edges=False):
    """Read ``m`` lines ``a b`` into an adjacency list of ``n`` vertices.

    Endpoints are shifted by ``base``. An undirected graph adds both
    directions. With ``return_edges`` the ``(a, b)`` list is returned too.
    """
    g = [[] for _ in range(n)]
    edges = []
    for _ in range(m):
        a, b = MII(base)
        if return_edges:
            edges.append((a, b))
        g[a].append(b)
        if not directed:
            g[b].append(a)
    if return_edges:
        return g, edges
    return g


def read_graph_with_weight(n, m, base=0, directed=False, return_edges=False):
    """Read ``m`` lines ``a b w`` into a weighted adjacency list.

    Only the endpoints are shifted by ``base``; ``w`` is kept as read.
    Adjacency entries are ``(neighbour, w)``. With ``return_edges`` the
    ``(a, b, w)`` list is returned too.
    """
    g = [[] for _ in range(n)]
    edges = []
    for _ in range(m):
        a, b, w = MII()
        a, b = a - base, b - base
        if return_edges:
            edges.append((a, b, w))
        g[a].append((b, w))
        if not directed:
            g[b].append((a, w))
    if return_edges:
        return g, edges
    return g


def read_edges_from_ps():
    """Read a parent list for vertices 1, 2, ... and return ``(parent, child)`` edges.

    Parents are read 1-based and converted to 0-based.
    """
    ps = LII(1)
    edges = []
    for i, p in enumerate(ps, 1):
        edges.append((p, i))
    return edges


def yes(res):
    print("Yes" if res else "No")


def YES(res):
    print("YES" if res else "NO")


def cmin(dp, i, x):
    """Set ``dp[i] = x`` if ``x`` is smaller."""
    if x < dp[i]:
        dp[i] = x


def cmax(dp, i, x):
    """Set ``dp[i] = x`` if ``x`` is larger."""
    if x > dp[i]:
        dp[i] = x


def alp_a_to_i(s):
    return ord(s) - ord("a")


def alp_A_to_i(s):
    return ord(s) - ord("A")


def alp_i_to_a(i):
    return chr(ord("a") + i)


def alp_i_to_A(i):
    return chr(ord("A") + i)


# 4- and 8-neighbour grid offsets.
d4 = [(1, 0), (0, 1), (-1, 0), (0, -1)]
d8 = [(1, 0), (1, 1), (0, 1), (-1, 1), (-1, 0), (-1, -1), (0, -1), (1, -1)]


def ranges(n, m):
    """Lazily yield every ``(i, j)`` in an ``n x m`` grid, row-major."""
    return ((i, j) for i in range(n) for j in range(m))


def rangess(a, b, c):
    """Lazily yield every ``(i, j, k)`` in an ``a x b x c`` box."""
    return ((i, j, k) for i in range(a) for j in range(b) for k in range(c))


def valid(i, j, n, m):
    """Return True when ``(i, j)`` lies inside an ``n x m`` grid."""
    return 0 <= i < n and 0 <= j < m


def ninj(i, j, n, m):
    """Return the in-bounds 4-neighbours of ``(i, j)`` in an ``n x m`` grid."""
    return [(i + di, j + dj) for di, dj in d4 if valid(i + di, j + dj, n, m)]


def gen(x, *args):
    """Return a nested list with dimensions ``args``, filled with ``x``.

    Every sub-list is a separate object, so rows do not alias. The innermost
    level is ``[x] * args[-1]``, so ``x`` itself is shared. Any number of
    dimensions is supported. Returns None when no dimension is given.
    """
    if not args:
        return None

    def build(dims):
        # Innermost level: a flat row; outer levels: fresh copies per index.
        if len(dims) == 1:
            return [x] * dims[0]
        return [build(dims[1:]) for _ in range(dims[0])]

    return build(args)


list2d = lambda a, b, v: [[v] * b for _ in range(a)]
list3d = lambda a, b, c, v: [[[v] * c for _ in range(b)] for _ in range(a)]


class Debug:
    """Debug-printer factory gated on a local ``.cph`` marker."""

    def __init__(self, debug=False):
        self.debug = debug
        cur_path = os.path.dirname(os.path.abspath(__file__))
        # True only when a ``.cph`` entry exists next to this file.
        self.local = os.path.exists(cur_path + "/.cph")

    def get_ic(self):
        """Return icecream's ``ic`` when debugging locally, else a no-op."""
        if self.debug and self.local:
            from icecream import ic

            return ic
        else:
            return lambda *args, **kwargs: ...
def pairwise(a):
    """Yield ``(a[i], a[i + 1])`` for each adjacent pair of the sequence ``a``."""
    n = len(a)
    for i in range(n - 1):
        yield a[i], a[i + 1]


def factorial(n):
    """Return ``n!`` for a non-negative integer ``n`` (``factorial(0) == 1``).

    Raises:
        ValueError: if ``n`` is negative.
    """
    if n < 0:
        raise ValueError("factorial() not defined for negative values")
    # Seeding reduce with 1 makes the empty product (n == 0) equal 1;
    # without an initializer reduce() raises TypeError on range(1, 1).
    return reduce(lambda x, y: x * y, range(1, n + 1), 1)


# icecream's ``ic`` when a local ``.cph`` marker exists, otherwise a no-op.
ic = Debug(1).get_ic()
# Shadows math.inf with an ``Inf`` comparison sentinel.
inf = Inf(-1)


class PrefixSum:
    """Prefix sums over ``a`` with inclusive-range queries.

    - ``ps[i]`` is ``a[0] + ... + a[i]``.
    - ``ps[l:r]`` is ``a[l] + ... + a[r]``. Note that ``r`` is INCLUSIVE;
      ``l`` defaults to 0 and ``r`` to the last index.
    - Iterating yields the ``n + 1`` prefix sums, starting with 0.
    """

    def __init__(self, a):
        self.n = len(a)
        self.sum = [0] * (self.n + 1)
        for i in range(1, self.n + 1):
            self.sum[i] = self.sum[i - 1] + a[i - 1]

    def __getitem__(self, key):
        if isinstance(key, slice):
            start = key.start if key.start is not None else 0
            stop = key.stop if key.stop is not None else self.n - 1
            return self.get_sum(start, stop)
        return self.sum[key + 1]

    def __iter__(self):
        return iter(self.sum)

    def __len__(self):
        return self.n

    def get_sum(self, l, r):
        """Return ``a[l] + ... + a[r]`` (inclusive); 0 when ``l > r``."""
        if l > r:
            return 0
        return self.sum[r + 1] - self.sum[l]

    def __repr__(self):
        return str(self.sum)


class SparseTable:
    """Static range queries in O(1) after O(N log N) preprocessing.

    ``query`` combines two possibly overlapping windows. ``func`` must
    therefore be associative and idempotent (min, max, gcd, ...).
    """

    def __init__(self, data: list, func=min):
        self.func = func
        # st[k][j] = func over data[j : j + 2**k].
        self.st = st = [list(data)]
        i, N = 1, len(st[0])
        # Build a level only while a window of width 2 * i fits. The former
        # bound ``N + 1`` appended one useless empty level.
        while 2 * i <= N:
            qz = st[-1]
            st.append([func(qz[j], qz[j + i]) for j in range(N - 2 * i + 1)])
            i <<= 1

    def query(self, begin: int, end: int):
        """Return ``func`` over ``data[begin..end]`` (inclusive, begin <= end)."""
        lg = (end - begin + 1).bit_length() - 1
        return self.func(self.st[lg][begin], self.st[lg][end - (1 << lg) + 1])


class EulerTour:
    """Euler tour of a tree, plus LCA and auxiliary (virtual) tree helpers.

    Add the ``vertex_num - 1`` edges with :meth:`add_edge`, call
    :meth:`build`, then use :meth:`use_LCA` / :meth:`use_AuxiliaryTree`.
    """

    def __init__(self, vertex_num: int):
        self.N = vertex_num
        self.edge: list[tuple[int, int]] = []
        # G[v] holds (neighbour, edge id) pairs.
        self.G: list[list[tuple[int, int]]] = [[] for _ in range(vertex_num)]
        self.lca = None

    def add_edge(self, u: int, v: int):
        """Add undirected edge ``u - v``; its id is the insertion index."""
        eid = len(self.edge)
        self.edge.append((u, v))
        self.G[u].append((v, eid))
        self.G[v].append((u, eid))

    def build(self, root=0):
        """Run an iterative DFS from ``root`` and record the Euler tour.

        Sets the following attributes:

        - ``tour``/``depth``: vertex and depth at every tour step, with a
          trailing ``-1`` sentinel.
        - ``node_in``/``node_out``: the first tour index of each vertex and
          one past its last.
        - ``node_depth``.
        - ``edge_in``/``edge_out``, indexed by edge id.
        - ``parent``: -1 for the root.
        """
        N, G = self.N, self.G
        self.root = root
        tour = []
        depth = []
        node_in = [-1] * N
        node_out = [-1] * N
        node_depth = [-1] * N
        edge_in = [-1] * (N - 1)
        edge_out = [-1] * (N - 1)
        parent = [-1] * N
        # Stack items are (vertex, depth, id of the edge that led here). A
        # vertex is pushed once to be entered and once more per child, so
        # the traversal re-enters it after each child's subtree.
        stk = [(root, 0, -1)]
        t = -1
        while stk:
            t += 1
            v, d, ei = stk.pop()
            if node_in[v] < 0:
                # First visit of v.
                node_in[v] = t
                node_depth[v] = d
                if ei >= 0:
                    edge_in[ei] = t
                tour.append(v)
                depth.append(d)
                is_leaf = True
                for nv, ne in G[v]:
                    # parent[root] stays -1, so the root must be excluded
                    # explicitly. Otherwise the first child processed would
                    # claim the root as its child, overwriting parent[root]
                    # and adding a bogus tour step for the root at depth 2.
                    if nv == root or parent[nv] >= 0:
                        continue
                    parent[nv] = v
                    stk.append((v, d, ne))
                    stk.append((nv, d + 1, ne))
                    is_leaf = False
                if is_leaf:
                    node_out[v] = t + 1
                    if ei >= 0:
                        edge_out[ei] = t + 1
            else:
                # Back at v after finishing the child reached via edge ei.
                node_out[v] = t + 1
                edge_out[ei] = t + 1
                tour.append(v)
                depth.append(d)
        tour.append(-1)
        depth.append(-1)
        self.tour = tour
        self.depth = depth
        self.node_in = node_in
        self.node_out = node_out
        self.node_depth = node_depth
        self.edge_in = edge_in
        self.edge_out = edge_out
        self.parent = parent

    class __LCA:
        """LCA via range-minimum over ``(depth, vertex)`` pairs of the tour."""

        def __init__(self, tour: list[int], tour_depth: list[int], node_in: list[int]):
            data = [(d, v) for d, v in zip(tour_depth, tour)]
            self._st = SparseTable(data, func=lambda x, y: x if x <= y else y)
            self._node_in = node_in

        def get(self, u, v):
            """Return the lowest common ancestor of ``u`` and ``v``."""
            node_in = self._node_in
            l, r = node_in[u], node_in[v]
            if l > r:
                l, r = r, l
            return self._st.query(l, r)[1]

    def use_LCA(self):
        """Return the (lazily built, cached) LCA helper; requires ``build``."""
        if self.lca is None:
            self.lca = self.__LCA(self.tour, self.depth, self.node_in)
        return self.lca

    class __AuxiliaryTree(dict):
        """Maps group id -> ``(children, parent, vertices)`` of its auxiliary tree.

        Each group's vertex set is closed under LCAs of DFS-adjacent members.
        ``children[p]`` lists the aux-tree children of ``p``. ``parent[v]``
        is the aux-tree parent, which is always a tree ancestor of ``v``.
        ``vertices`` is sorted by ``node_in``.
        """

        def __init__(
            self,
            vertex_group_id: list[int],
            special_nodes: list[int],
            node_in: list[int],
            node_out: list[int],
            lca,
            parent,
        ):
            V: dict[int, list[int]] = dict()
            if not ((vertex_group_id is None) ^ (special_nodes is None)):
                raise ValueError(
                    "exactly one of vertex_group_id and special_nodes must be given"
                )
            if vertex_group_id is not None:
                for v, g in enumerate(vertex_group_id):
                    if g not in V:
                        V[g] = []
                    V[g].append(v)
            if special_nodes is not None:
                V[1] = special_nodes[::]
            for k, vv in V.items():
                vv.sort(key=lambda v: node_in[v])
                # range() is fixed before the loop, so only the original
                # adjacent pairs are visited while LCAs are appended.
                for i in range(1, len(vv)):
                    vv.append(lca.get(vv[i - 1], vv[i]))
                vv = sorted(set(vv), key=lambda v: node_in[v])
                G: dict[int, list[int]] = dict()
                P: dict[int, int] = dict()
                stk: list[int] = []
                for v in vv:
                    # Pop stack entries whose subtree ends before v's.
                    while stk and node_out[stk[-1]] <= node_out[v]:
                        stk.pop()
                    if stk:
                        p = stk[-1]
                        if p not in G:
                            G[p] = []
                        G[p].append(v)
                        P[v] = p
                    stk.append(v)
                self[k] = (G, P, vv)

    def use_AuxiliaryTree(
        self, vertex_group_id: list[int] = None, special_nodes: list[int] = None
    ) -> dict[int, tuple[dict[int, list[int]], dict[int, int], list[int]]]:
        """Build auxiliary trees for vertex groups (or one ``special_nodes`` set, key 1).

        Exactly one of the two arguments must be given; otherwise ValueError
        is raised.
        """
        return self.__AuxiliaryTree(
            vertex_group_id,
            special_nodes,
            self.node_in,
            self.node_out,
            self.use_LCA(),
            self.parent,
        )


class Tree:
    """Rooted tree exposing ``parent``, ``depth`` and a parent-first ``order``.

    Built from either an adjacency list ``g`` or an edge list ``edges``
    (``edges`` takes precedence). ``parent[root]`` is -1.
    """

    def __init__(self, g=None, edges=None, root=0, vals=None):
        # ``vals`` is accepted for compatibility but unused; None avoids a
        # shared mutable default.
        if edges is not None:
            self.n = n = len(edges) + 1
            self.g = g = [[] for _ in range(n)]
            for u, v in edges:
                self.g[u].append(v)
                self.g[v].append(u)
        else:
            self.n = n = len(g)
            self.g = g
        self.root = root
        self.parent = parent = [-1] * n
        stk = [root]
        # Vertices are appended when discovered, i.e. after their parent.
        self.order = order = [root]
        self.depth = depth = [0] * n
        while stk:
            u = stk.pop()
            for v in g[u]:
                if v != root and parent[v] == -1:
                    depth[v] = depth[u] + 1
                    parent[v] = u
                    stk.append(v)
                    order.append(v)


def main():
    """Answer weighted-tree queries with the weight of a minimal spanning subtree.

    Input format:

    - ``n``;
    - ``n - 1`` lines ``u v w`` with 0-based vertices;
    - ``q``;
    - ``q`` lines whose first integer is ignored and whose remaining
      integers are the query vertices.

    For each query this prints the total weight of the minimal subtree
    connecting those vertices. That weight is the sum of path weights over
    the edges of their auxiliary tree.
    """
    n = II()
    T = EulerTour(n)
    _, edges = read_graph_with_weight(n, n - 1, 0, return_edges=True)
    for u, v, _ in edges:
        T.add_edge(u, v)
    T.build()
    t = Tree(edges=[(u, v) for u, v, _ in edges])
    depth, parent = t.depth, t.parent
    # dist_root[v]: total edge weight on the path from vertex 0 to v.
    dist_root = [0] * n
    for u, v, w in edges:
        child = u if depth[u] > depth[v] else v
        dist_root[child] = w
    for u in t.order[1:]:
        dist_root[u] += dist_root[parent[u]]
    for _ in range(II()):
        vs = LII()[1:]
        res = 0
        for children, _, _ in T.use_AuxiliaryTree(special_nodes=vs).values():
            for u, kids in children.items():
                du = dist_root[u]
                # An aux-tree parent is a tree ancestor of its child, so the
                # path weight is a plain difference, with no LCA lookup needed.
                for v in kids:
                    res += dist_root[v] - du
        print(res)


if __name__ == "__main__":
    main()