結果
問題 | No.901 K-ary εxtrεεmε |
ユーザー | StanMarsh |
提出日時 | 2024-02-27 00:18:53 |
言語 | PyPy3 (7.3.15) |
結果 | AC |
実行時間 | 1,066 ms / 3,000 ms |
コード長 | 15,219 bytes |
コンパイル時間 | 309 ms |
コンパイル使用メモリ | 81,828 KB |
実行使用メモリ | 234,224 KB |
最終ジャッジ日時 | 2024-02-27 00:19:23 |
合計ジャッジ時間 | 28,490 ms |
ジャッジサーバーID (参考情報) | judge13 / judge11 |
(要ログイン)
テストケース
テストケース表示入力 | 結果 | 実行時間 実行使用メモリ |
---|---|---|
testcase_00 | AC | 732 ms
234,224 KB |
testcase_01 | AC | 153 ms
91,324 KB |
testcase_02 | AC | 218 ms
92,732 KB |
testcase_03 | AC | 216 ms
92,732 KB |
testcase_04 | AC | 210 ms
92,860 KB |
testcase_05 | AC | 211 ms
92,732 KB |
testcase_06 | AC | 213 ms
92,604 KB |
testcase_07 | AC | 1,022 ms
209,660 KB |
testcase_08 | AC | 1,023 ms
208,888 KB |
testcase_09 | AC | 988 ms
208,884 KB |
testcase_10 | AC | 990 ms
209,244 KB |
testcase_11 | AC | 1,013 ms
209,616 KB |
testcase_12 | AC | 1,030 ms
209,636 KB |
testcase_13 | AC | 1,028 ms
210,900 KB |
testcase_14 | AC | 1,038 ms
210,888 KB |
testcase_15 | AC | 1,014 ms
209,620 KB |
testcase_16 | AC | 1,005 ms
208,852 KB |
testcase_17 | AC | 1,050 ms
209,724 KB |
testcase_18 | AC | 1,034 ms
211,504 KB |
testcase_19 | AC | 1,012 ms
210,468 KB |
testcase_20 | AC | 1,015 ms
209,712 KB |
testcase_21 | AC | 1,054 ms
211,504 KB |
testcase_22 | AC | 1,066 ms
219,232 KB |
testcase_23 | AC | 1,048 ms
220,612 KB |
testcase_24 | AC | 1,045 ms
221,276 KB |
testcase_25 | AC | 1,057 ms
221,008 KB |
testcase_26 | AC | 1,049 ms
220,616 KB |
testcase_27 | AC | 920 ms
209,692 KB |
testcase_28 | AC | 945 ms
211,480 KB |
testcase_29 | AC | 948 ms
210,964 KB |
ソースコード
from random import getrandbits, randrange
from string import ascii_lowercase, ascii_uppercase
import sys
from math import ceil, floor, sqrt, pi, factorial, gcd, log, log10, log2, inf, cos, sin
from copy import deepcopy, copy
from collections import Counter, deque, defaultdict
from heapq import heapify, heappop, heappush
from itertools import (
    accumulate,
    chain,
    product,
    combinations,
    combinations_with_replacement,
    permutations,
)
from bisect import bisect, bisect_left, bisect_right
from functools import lru_cache, reduce
from decimal import Decimal, getcontext
from typing import List, Tuple, Optional


class Inf:
    """Sentinel that orders above every non-``Inf`` value.

    ``value`` is only a label: it is used by ``__eq__`` and ``__repr__`` and
    never takes part in ordering comparisons.
    """

    def __init__(self, value):
        self.value = value

    def __lt__(self, other):
        # Never strictly smaller than anything, another Inf included.
        return False

    def __le__(self, other):
        # "<=" holds only against another Inf.
        if isinstance(other, Inf):
            return True
        return False

    def __gt__(self, other):
        # Strictly greater than everything except another Inf.
        if isinstance(other, Inf):
            return False
        return True

    def __ge__(self, other):
        return True

    def __eq__(self, other):
        return isinstance(other, Inf) and self.value == other.value

    def __repr__(self):
        return f"{self.value}"

    # Arithmetic is absorbing: the result is always an Inf carrying this
    # instance's label (a fresh object when `other` is an Inf, `self` otherwise).
    def __add__(self, other):
        return Inf(self.value) if isinstance(other, Inf) else self

    def __sub__(self, other):
        return Inf(self.value) if isinstance(other, Inf) else self

    def __mul__(self, other):
        return Inf(self.value) if isinstance(other, Inf) else self


def ceil_div(a, b):
    """Return ``(a + b - 1) // b`` (ceiling division for positive ``b``)."""
    return (a + b - 1) // b


def isqrt(num):
    """Return the integer square root of ``num``.

    Starts from the floating-point estimate and nudges it down/up until
    ``res * res <= num < (res + 1) * (res + 1)``.
    """
    res = int(sqrt(num))
    while res * res > num:
        res -= 1
    while (res + 1) * (res + 1) <= num:
        res += 1
    return res


def int1(s):
    """Parse ``s`` as an int and subtract one (1-based -> 0-based)."""
    return int(s) - 1


from types import GeneratorType


def bootstrap(f, stack=[]):
    """Decorator that runs a generator-style recursive function iteratively.

    The decorated function is expected to ``yield`` its recursive calls and
    ``yield`` its result.  ``stack`` is a default list shared by all calls of
    the wrapped function: while it is non-empty a call is "nested" and the
    wrapper just returns what ``f`` produced, leaving the outermost call to
    drive every generator.
    """

    def wrapped(*args, **kwargs):
        if stack:
            # Nested call: hand the generator back to the driving loop below.
            return f(*args, **kwargs)
        else:
            to = f(*args, **kwargs)
            while True:
                if type(to) is GeneratorType:
                    # A new "frame": remember it and advance to its first yield.
                    stack.append(to)
                    to = next(to)
                else:
                    # A plain value: the top frame is finished; pass the value
                    # to the frame underneath, or stop when none is left.
                    stack.pop()
                    if not stack:
                        break
                    to = stack[-1].send(to)
            return to

    return wrapped


import sys
import os
from io import BytesIO, IOBase

BUFSIZE = 8192


class FastIO(IOBase):
    """Byte-level buffered reader/writer working directly on a file descriptor."""

    # Number of newline characters read into the buffer but not yet consumed
    # by readline().
    newlines = 0

    def __init__(self, file):
        self._fd = file.fileno()
        self.buffer = BytesIO()
        # Treated as writable when the mode contains "x" or lacks "r".
        self.writable = "x" in file.mode or "r" not in file.mode
        self.write = self.buffer.write if self.writable else None

    def read(self):
        """Read from the descriptor until ``os.read`` returns nothing; return the unread bytes."""
        while True:
            b = os.read(self._fd, max(os.fstat(self._fd).st_size, BUFSIZE))
            if not b:
                break
            # Append at the end of the buffer, then restore the read position.
            ptr = self.buffer.tell()
            self.buffer.seek(0, 2), self.buffer.write(b), self.buffer.seek(ptr)
        self.newlines = 0
        return self.buffer.read()

    def readline(self):
        """Return the next line as bytes, reading more input only when no newline is buffered."""
        while self.newlines == 0:
            b = os.read(self._fd, max(os.fstat(self._fd).st_size, BUFSIZE))
            # `not b` adds 1 on an empty read so the loop also ends at EOF.
            self.newlines = b.count(b"\n") + (not b)
            ptr = self.buffer.tell()
            self.buffer.seek(0, 2), self.buffer.write(b), self.buffer.seek(ptr)
        self.newlines -= 1
        return self.buffer.readline()

    def flush(self):
        """Write the buffered bytes to the descriptor and empty the buffer (writable streams only)."""
        if self.writable:
            os.write(self._fd, self.buffer.getvalue())
            self.buffer.truncate(0), self.buffer.seek(0)


class IOWrapper(IOBase):
    """Text facade over ``FastIO`` that encodes/decodes as ASCII."""

    def __init__(self, file):
        self.buffer = FastIO(file)
        self.flush = self.buffer.flush
        self.writable = self.buffer.writable
        self.write = lambda s: self.buffer.write(s.encode("ascii"))
        self.read = lambda: self.buffer.read().decode("ascii")
        self.readline = lambda: self.buffer.readline().decode("ascii")


# Replace the standard streams and shadow the built-in input()/print() so that
# everything below goes through the wrappers above.
sys.stdin, sys.stdout = IOWrapper(sys.stdin), IOWrapper(sys.stdout)
input = lambda: sys.stdin.readline().rstrip("\r\n")
print = lambda *args, end="\n", sep=" ": sys.stdout.write(
    sep.join(map(str, args)) + end
)


def II():
    """Read one line and return it as an int."""
    return int(input())


def MII(base=0):
    """Read one line and return a ``map`` of its ints, each reduced by ``base``."""
    return map(lambda s: int(s) - base, input().split())


def LII(base=0):
    """Like ``MII`` but materialised as a list."""
    return list(MII(base))


def NA():
    """Read an int ``n`` on one line and a list of ints on the next; return both."""
    n = II()
    a = LII()
    return n, a


def read_graph(n, m, base=0, directed=False, return_edges=False):
    """Read ``m`` edges ``a b`` and build an adjacency list over ``n`` vertices.

    Both endpoints are reduced by ``base``.  The reverse edge is added unless
    ``directed`` is true.  Returns ``g`` or, with ``return_edges``,
    ``(g, edges)`` where ``edges`` holds the ``(a, b)`` pairs in input order.
    """
    g = [[] for _ in range(n)]
    edges = []
    for _ in range(m):
        a, b = MII(base)
        if return_edges:
            edges.append((a, b))
        g[a].append(b)
        if not directed:
            g[b].append(a)
    if return_edges:
        return g, edges
    return g


def read_graph_with_weight(n, m, base=0, directed=False, return_edges=False):
    """Read ``m`` weighted edges ``a b w``; adjacency entries are ``(neighbour, w)``.

    Only the endpoints are reduced by ``base``; the weight is kept as read.
    Returns ``g`` or, with ``return_edges``, ``(g, edges)`` where ``edges``
    holds ``(a, b, w)`` triples in input order.
    """
    g = [[] for _ in range(n)]
    edges = []
    for _ in range(m):
        a, b, w = MII()
        a, b = a - base, b - base
        if return_edges:
            edges.append((a, b, w))
        g[a].append((b, w))
        if not directed:
            g[b].append((a, w))
    if return_edges:
        return g, edges
    return g


def read_edges_from_ps():
    """Read one line of values, subtract one from each, and return ``(p, i)`` pairs.

    ``i`` counts from 1, i.e. the k-th value on the line is paired with index
    ``k`` (presumably a parent list for vertices 1..n-1 - confirm with caller).
    """
    ps = LII(1)
    edges = []
    for i, p in enumerate(ps, 1):
        edges.append((p, i))
    return edges


def yes(res):
    """Print "Yes" when ``res`` is truthy, otherwise "No"."""
    print("Yes" if res else "No")


def YES(res):
    """Print "YES" when ``res`` is truthy, otherwise "NO"."""
    print("YES" if res else "NO")


def cmin(dp, i, x):
    """Lower ``dp[i]`` to ``x`` when ``x`` is smaller."""
    if x < dp[i]:
        dp[i] = x


def cmax(dp, i, x):
    """Raise ``dp[i]`` to ``x`` when ``x`` is larger."""
    if x > dp[i]:
        dp[i] = x


def alp_a_to_i(s):
    """Return the offset of character ``s`` from "a"."""
    return ord(s) - ord("a")


def alp_A_to_i(s):
    """Return the offset of character ``s`` from "A"."""
    return ord(s) - ord("A")


def alp_i_to_a(i):
    """Return the character ``i`` positions after "a"."""
    return chr(ord("a") + i)


def alp_i_to_A(i):
    """Return the character ``i`` positions after "A"."""
    return chr(ord("A") + i)


# Coordinate offsets of the 4 edge-adjacent and the 8 surrounding grid cells.
d4 = [(1, 0), (0, 1), (-1, 0), (0, -1)]
d8 = [(1, 0), (1, 1), (0, 1), (-1, 1), (-1, 0), (-1, -1), (0, -1), (1, -1)]


def ranges(n, m):
    """Generate every ``(i, j)`` with ``0 <= i < n`` and ``0 <= j < m``, ``i``-major."""
    return ((i, j) for i in range(n) for j in range(m))


def rangess(a, b, c):
    """Three-dimensional counterpart of ``ranges``."""
    return ((i, j, k) for i in range(a) for j in range(b) for k in range(c))


def valid(i, j, n, m):
    """Return whether ``(i, j)`` lies inside an ``n`` x ``m`` grid."""
    return 0 <= i < n and 0 <= j < m


def ninj(i, j, n, m):
    """Return the ``d4`` neighbours of ``(i, j)`` that lie inside the ``n`` x ``m`` grid."""
    return [(i + di, j + dj) for di, dj in d4 if valid(i + di, j + dj, n, m)]


def gen(x, *args):
    """Build a nested list filled with ``x`` whose dimensions are given by ``args``.

    Handles 1 to 4 dimensions; for any other number of dimensions the function
    falls through and returns ``None``.  The innermost rows are ``[x] * k``, so
    every cell refers to the same ``x`` object.
    """
    if len(args) == 1:
        return [x] * args[0]
    if len(args) == 2:
        return [[x] * args[1] for _ in [0] * args[0]]
    if len(args) == 3:
        return [[[x] * args[2] for _ in [0] * args[1]] for _ in [0] * args[0]]
    if len(args) == 4:
        return [
            [[[x] * args[3] for _ in [0] * args[2]] for _ in [0] * args[1]]
            for _ in [0] * args[0]
        ]


# a x b and a x b x c nested lists filled with v.
list2d = lambda a, b, v: [[v] * b for _ in range(a)]
list3d = lambda a, b, c, v: [[[v] * c for _ in range(b)] for _ in range(a)]


class Debug:
    """Chooses between icecream's ``ic`` and a no-op, depending on the environment."""

    def __init__(self, debug=False):
        self.debug = debug
        # "Local" means a ".cph" entry exists next to this script.
        cur_path = os.path.dirname(os.path.abspath(__file__))
        self.local = os.path.exists(cur_path + "/.cph")

    def get_ic(self):
        """Return ``icecream.ic`` when debugging locally, else a callable that ignores its arguments."""
        if self.debug and self.local:
            # Imported lazily so the third-party package is only needed locally.
            from icecream import ic

            return ic
        else:
            return lambda *args, **kwargs: ...
def pairwise(a):
    """Yield consecutive pairs ``(a[i], a[i + 1])`` of an indexable sequence."""
    n = len(a)
    for i in range(n - 1):
        yield a[i], a[i + 1]


def factorial(n):
    """Return ``n!`` as the product of ``1..n``.

    Shadows the ``factorial`` imported from ``math`` at the top of the file.
    The initial value 1 makes ``factorial(0)`` return 1 instead of raising
    ``TypeError`` on the empty range.
    """
    return reduce(lambda x, y: x * y, range(1, n + 1), 1)


# Debug printer (a no-op unless running locally with icecream available).
ic = Debug(1).get_ic()
# Replaces the float `inf` imported from math with the Inf sentinel.
inf = Inf(-1)


class PrefixSum:
    """Prefix sums over a sequence, with inclusive range queries."""

    def __init__(self, a):
        self.n = len(a)
        # sum[i] holds a[0] + ... + a[i - 1]; sum[0] is 0.
        self.sum = [0] * (self.n + 1)
        for i in range(1, self.n + 1):
            self.sum[i] = self.sum[i - 1] + a[i - 1]

    def __getitem__(self, key):
        """Return ``a[0] + ... + a[key]`` for an int key.

        For a slice, return the sum over the *inclusive* range
        ``[start, stop]`` (defaults ``0`` and ``n - 1``); the step is ignored.
        """
        if isinstance(key, slice):
            start = key.start if key.start is not None else 0
            stop = key.stop if key.stop is not None else self.n - 1
            return self.get_sum(start, stop)
        return self.sum[key + 1]

    def __iter__(self):
        return iter(self.sum)

    def __len__(self):
        return self.n

    def get_sum(self, l, r):
        """Return ``a[l] + ... + a[r]`` (inclusive), or 0 when ``l > r``."""
        if l > r:
            return 0
        return self.sum[r + 1] - self.sum[l]

    def __repr__(self):
        return str(self.sum)


class SparseTable:
    """Sparse table answering ``func`` over inclusive index ranges.

    ``query`` combines two possibly overlapping windows, so ``func`` has to be
    insensitive to overlap (as ``min`` is).
    """

    def __init__(self, data: list, func=min):
        self.func = func
        # st[k][j] covers data[j : j + 2**k].
        self.st = st = [list(data)]
        i, N = 1, len(st[0])
        while 2 * i <= N + 1:
            qz = st[-1]
            st.append([func(qz[j], qz[j + i]) for j in range(N - 2 * i + 1)])
            i <<= 1

    def query(self, begin: int, end: int):
        """Return ``func`` over ``data[begin..end]`` (both ends inclusive)."""
        lg = (end - begin + 1).bit_length() - 1
        return self.func(self.st[lg][begin], self.st[lg][end - (1 << lg) + 1])


class EulerTour:
    """Euler tour of a tree, with LCA and auxiliary (virtual) tree helpers."""

    def __init__(self, vertex_num: int):
        self.N = vertex_num
        self.edge: list[tuple[int, int]] = []
        # G[u] holds (neighbour, edge id) pairs.
        self.G: list[list[tuple[int, int]]] = [[] for _ in range(vertex_num)]
        self.lca = None

    def add_edge(self, u: int, v: int):
        """Add the undirected edge ``u - v``; its id is its insertion index."""
        eid = len(self.edge)
        self.edge.append((u, v))
        self.G[u].append((v, eid))
        self.G[v].append((u, eid))

    def build(self, root=0):
        """Run an iterative DFS from ``root`` and store the tour arrays on ``self``.

        Stored attributes:
          * ``tour`` / ``depth``: vertex and depth at every step; a vertex is
            recorded on entry and again after each child returns.  A ``-1``
            sentinel is appended to both lists.
          * ``node_in[v]``: tour index of the first visit of ``v``.
          * ``node_out[v]``: one past the tour index of the last occurrence
            of ``v``.
          * ``node_depth[v]``: depth of ``v`` (root is 0).
          * ``edge_in`` / ``edge_out``: per edge id, the tour index at which
            the edge is first descended and one past the index of the latest
            step recorded for it.
          * ``parent[v]``: DFS parent, ``-1`` for the root.

        The edge arrays are sized ``N - 1``, i.e. the graph is expected to be
        a tree.
        """
        N, G = self.N, self.G
        self.root = root
        tour = []
        depth = []
        node_in = [-1] * N
        node_out = [-1] * N
        node_depth = [-1] * N
        edge_in = [-1] * (N - 1)
        edge_out = [-1] * (N - 1)
        parent = [-1] * N
        # Stack entries are (vertex, depth, id of the edge being traversed).
        stk = [(root, 0, -1)]
        t = -1
        while stk:
            # Every pop appends exactly one tour entry, so t is its index.
            t += 1
            v, d, ei = stk.pop()
            if node_in[v] < 0:
                # First visit of v.
                node_in[v] = t
                node_depth[v] = d
                if ei >= 0:
                    edge_in[ei] = t
                tour.append(v)
                depth.append(d)
                is_leaf = True
                for nv, ne in G[v]:
                    # The root keeps parent == -1, so it needs an explicit
                    # check; without it the root's first child would treat the
                    # root as unvisited, overwrite parent[root] and add a
                    # spurious tour entry.
                    if nv == root or parent[nv] >= 0:
                        continue
                    parent[nv] = v
                    # Come back to v after the child's subtree is done.
                    stk.append((v, d, ne))
                    stk.append((nv, d + 1, ne))
                    is_leaf = False
                if is_leaf:
                    node_out[v] = t + 1
                    if ei >= 0:
                        edge_out[ei] = t + 1
            else:
                # Returned to v from the child reached through edge ei.
                node_out[v] = t + 1
                edge_out[ei] = t + 1
                tour.append(v)
                depth.append(d)
        tour.append(-1)
        depth.append(-1)
        self.tour = tour
        self.depth = depth
        self.node_in = node_in
        self.node_out = node_out
        self.node_depth = node_depth
        self.edge_in = edge_in
        self.edge_out = edge_out
        self.parent = parent

    class __LCA:
        """LCA by range minimum over ``(depth, vertex)`` pairs of the tour."""

        def __init__(self, tour: list[int], tour_depth: list[int], node_in: list[int]):
            data = [(d, v) for d, v in zip(tour_depth, tour)]
            self._st = SparseTable(data, func=lambda x, y: x if x <= y else y)
            self._node_in = node_in

        def get(self, u, v):
            """Return the lowest common ancestor of ``u`` and ``v``."""
            node_in = self._node_in
            l, r = node_in[u], node_in[v]
            if l > r:
                l, r = r, l
            # The shallowest tour entry between the two first visits.
            return self._st.query(l, r)[1]

    def use_LCA(self):
        """Return the LCA helper, building it on first use (requires ``build``)."""
        if self.lca is None:
            self.lca = self.__LCA(self.tour, self.depth, self.node_in)
        return self.lca

    class __AuxiliaryTree(dict):
        """Auxiliary trees keyed by group id.

        Each value is ``(G, P, vv)``: children lists, parent map and the
        vertices of the compressed tree sorted by ``node_in``.
        """

        def __init__(
            self,
            vertex_group_id: list[int],
            special_nodes: list[int],
            node_in: list[int],
            node_out: list[int],
            lca,
            parent,
        ):
            V: dict[int, list[int]] = dict()
            # Exactly one of the two inputs must be given.
            if not ((vertex_group_id is None) ^ (special_nodes is None)):
                raise ValueError
            if vertex_group_id is not None:
                for v, g in enumerate(vertex_group_id):
                    if g not in V:
                        V[g] = []
                    V[g].append(v)
            if special_nodes is not None:
                # A single group, stored under key 1; copied so the caller's
                # list is not modified by the sort/append below.
                V[1] = special_nodes[::]
            for k, vv in V.items():
                vv.sort(key=lambda v: node_in[v])
                # Add the LCA of every pair adjacent in DFS order (the range
                # is fixed before the appends start).
                for i in range(1, len(vv)):
                    vv.append(lca.get(vv[i - 1], vv[i]))
                vv = sorted(set(vv), key=lambda v: node_in[v])
                G: dict[int, list[int]] = dict()
                P: dict[int, int] = dict()
                stk: list[int] = []
                for v in vv:
                    # Drop stack entries whose tour interval ended before v's.
                    while stk and node_out[stk[-1]] <= node_out[v]:
                        stk.pop()
                    if stk:
                        p = stk[-1]
                        if p not in G:
                            G[p] = []
                        G[p].append(v)
                        P[v] = p
                    stk.append(v)
                self[k] = (G, P, vv)

    def use_AuxiliaryTree(
        self, vertex_group_id: list[int] = None, special_nodes: list[int] = None
    ) -> dict[int, tuple[dict[int, list[int]], dict[int, int], list[int]]]:
        """Build auxiliary trees from a per-vertex group list or from one vertex set.

        Exactly one of the two arguments must be supplied; otherwise
        ``ValueError`` is raised.
        """
        return self.__AuxiliaryTree(
            vertex_group_id,
            special_nodes,
            self.node_in,
            self.node_out,
            self.use_LCA(),
            self.parent,
        )


def MyEulerTour(n, G, i0=0):
    """Iterative Euler tour over ``G`` starting at ``i0``.

    ``G[i]`` holds pairs whose first item is the neighbour.  Returns
    ``(ET, nodein, nodeout, depth)``.  ``ET`` records a vertex on entry and
    its parent on exit; ``nodein`` / ``nodeout`` store the entry counter
    ``ct``, which advances only on entries.
    """
    P = [-1] * n
    # ~i marks "leaving i"; it is pushed under i so it pops after the subtree.
    stack = [~i0, i0]
    ct = -1
    depth = [-1] * n
    nodein = [-1] * n
    nodeout = [-1] * n
    ET = []
    de = 0
    while stack:
        i = stack.pop()
        if i < 0:
            ET.append(P[~i])
            nodeout[~i] = ct
            de -= 1
            continue
        if i >= 0:
            ct += 1
            ET.append(i)
            if nodein[i] == -1:
                nodein[i] = ct
            depth[i] = de
            de += 1
            for v, _ in G[i][::-1]:
                if v == P[i]:
                    continue
                P[v] = i
                stack.append(~v)
                stack.append(v)
    return ET, nodein, nodeout, depth


class Tree:
    """Rooted tree exposing ``parent``, ``depth`` and a traversal ``order``."""

    def __init__(self, g=None, edges=None, root=0, vals=None):
        """Build from an adjacency list ``g`` or from ``edges`` of ``(u, v)`` pairs.

        With ``edges`` the vertex count is ``len(edges) + 1``.  ``vals`` is
        accepted but not used; its default is ``None`` rather than a shared
        mutable list.
        """
        if edges is not None:
            self.n = n = len(edges) + 1
            self.g = g = [[] for _ in range(n)]
            for u, v in edges:
                self.g[u].append(v)
                self.g[v].append(u)
        else:
            self.n = n = len(g)
            self.g = g
        self.root = root
        self.parent = parent = [-1] * n
        stk = [root]
        self.order = order = [root]
        self.depth = depth = [0] * n
        while stk:
            u = stk.pop()
            for v in g[u]:
                # parent == -1 marks "not reached yet"; the root also keeps -1,
                # hence the explicit root check.
                if v != root and parent[v] == -1:
                    depth[v] = depth[u] + 1
                    parent[v] = u
                    stk.append(v)
                    order.append(v)


# ---- main ----
# Input: n, then n - 1 weighted edges "u v w" read with base 0.
n = II()
T = EulerTour(n)
g, edges = read_graph_with_weight(n, n - 1, 0, return_edges=True)
for u, v, w in edges:
    T.add_edge(u, v)
T.build()
lca = T.use_LCA()
t = Tree(edges=[(u, v) for u, v, _ in edges])
depth, parent = t.depth, t.parent

# a[v]: weight of the edge between v and its parent (0 for the root).
a = [0] * n
for u, v, w in edges:
    child = u if depth[u] > depth[v] else v
    a[child] = w

# b[v]: total weight on the root-to-v path.  Walking the tour in first-visit
# order guarantees the parent is finished before the child.
b = a[::]
vst = [0] * n
for u in T.tour:
    if u < 0 or vst[u] == 1:
        continue
    vst[u] = 1
    p = parent[u]
    if p != -1:
        b[u] += b[p]


def query(u, v):
    """Return the weighted distance between ``u`` and ``v``."""
    return b[u] + b[v] - 2 * b[lca.get(u, v)]


# Each query line is "k x1 ... xk"; the leading count is dropped.  The answer
# is the sum of distances over the edges of the auxiliary tree of the x's.
for _ in range(II()):
    a = LII()[1:]
    res = 0
    for children, _, tour in T.use_AuxiliaryTree(special_nodes=a).values():
        for u in reversed(tour):
            for v in children.get(u, []):
                res += query(u, v)
    print(res)