結果
問題 | No.901 K-ary εxtrεεmε |
ユーザー | StanMarsh |
提出日時 | 2024-02-27 00:27:36 |
言語 | PyPy3 (7.3.15) |
結果 | AC |
実行時間 | 885 ms / 3,000 ms |
コード長 | 14,449 bytes |
コンパイル時間 | 303 ms |
コンパイル使用メモリ | 82,328 KB |
実行使用メモリ | 233,828 KB |
最終ジャッジ日時 | 2024-09-29 11:51:43 |
合計ジャッジ時間 | 23,974 ms |
ジャッジサーバーID (参考情報) | judge4 / judge1 |
(要ログイン)
テストケース
テストケース | 結果 | 実行時間 | 実行使用メモリ |
---|---|---|---|
testcase_00 | AC | 634 ms | 233,828 KB |
testcase_01 | AC | 131 ms | 92,648 KB |
testcase_02 | AC | 201 ms | 93,632 KB |
testcase_03 | AC | 184 ms | 93,612 KB |
testcase_04 | AC | 185 ms | 93,660 KB |
testcase_05 | AC | 181 ms | 93,716 KB |
testcase_06 | AC | 184 ms | 93,704 KB |
testcase_07 | AC | 824 ms | 211,184 KB |
testcase_08 | AC | 848 ms | 210,264 KB |
testcase_09 | AC | 831 ms | 210,028 KB |
testcase_10 | AC | 830 ms | 210,760 KB |
testcase_11 | AC | 842 ms | 210,128 KB |
testcase_12 | AC | 882 ms | 211,256 KB |
testcase_13 | AC | 885 ms | 211,336 KB |
testcase_14 | AC | 872 ms | 210,120 KB |
testcase_15 | AC | 855 ms | 210,764 KB |
testcase_16 | AC | 873 ms | 210,380 KB |
testcase_17 | AC | 867 ms | 211,252 KB |
testcase_18 | AC | 858 ms | 212,140 KB |
testcase_19 | AC | 862 ms | 211,220 KB |
testcase_20 | AC | 848 ms | 210,732 KB |
testcase_21 | AC | 852 ms | 212,268 KB |
testcase_22 | AC | 875 ms | 221,524 KB |
testcase_23 | AC | 875 ms | 222,544 KB |
testcase_24 | AC | 857 ms | 223,180 KB |
testcase_25 | AC | 874 ms | 219,452 KB |
testcase_26 | AC | 854 ms | 218,556 KB |
testcase_27 | AC | 771 ms | 211,352 KB |
testcase_28 | AC | 791 ms | 211,988 KB |
testcase_29 | AC | 781 ms | 211,952 KB |
ソースコード
# Competitive-programming template: broad stdlib imports, fast buffered stdio,
# input-parsing helpers, grid/array utilities and a debug-printer switch.
from random import getrandbits, randrange
from string import ascii_lowercase, ascii_uppercase
import sys
from math import ceil, floor, sqrt, pi, factorial, gcd, log, log10, log2, inf, cos, sin
from math import isqrt as _math_isqrt  # exact integer square root, used by isqrt()
from copy import deepcopy, copy
from collections import Counter, deque, defaultdict
from heapq import heapify, heappop, heappush
from itertools import (
    accumulate,
    chain,
    product,
    combinations,
    combinations_with_replacement,
    permutations,
)
from bisect import bisect, bisect_left, bisect_right
from functools import lru_cache, reduce
from decimal import Decimal, getcontext
from typing import List, Tuple, Optional


class Inf:
    """Sentinel that orders above every non-``Inf`` value.

    Comparison rules as implemented:
    * ``inf < x`` is always False and ``inf >= x`` is always True;
    * ``inf > x`` is True unless ``x`` is also an ``Inf``;
    * ``inf <= x`` is True only when ``x`` is an ``Inf``;
    * ``==`` requires both operands to be ``Inf`` with equal ``value``.

    ``+``, ``-`` and ``*`` with a non-``Inf`` operand return ``self``
    unchanged; with an ``Inf`` operand they return a new ``Inf`` carrying this
    object's ``value``. ``__eq__`` is defined without ``__hash__``, so
    instances are unhashable.
    """

    def __init__(self, value):
        self.value = value

    def __lt__(self, other):
        return False

    def __le__(self, other):
        if isinstance(other, Inf):
            return True
        return False

    def __gt__(self, other):
        if isinstance(other, Inf):
            return False
        return True

    def __ge__(self, other):
        return True

    def __eq__(self, other):
        return isinstance(other, Inf) and self.value == other.value

    def __repr__(self):
        return f"{self.value}"

    def __add__(self, other):
        return Inf(self.value) if isinstance(other, Inf) else self

    def __sub__(self, other):
        return Inf(self.value) if isinstance(other, Inf) else self

    def __mul__(self, other):
        return Inf(self.value) if isinstance(other, Inf) else self


def ceil_div(a, b):
    """Return ceil(a / b) for integers ``a`` and ``b`` (``b != 0``).

    The former ``(a + b - 1) // b`` form is only correct when ``b > 0``.
    Negating a floor division is exact for every sign combination and gives
    the same result whenever ``b > 0``. Raises ZeroDivisionError if ``b == 0``.
    """
    return -(-a // b)


def isqrt(num):
    """Return the largest integer ``r`` with ``r * r <= num``.

    Python ints go through ``math.isqrt``, which is exact for arbitrarily large
    values. The float-seeded path below raises OverflowError for ints above
    roughly 1e308. Other numeric types keep that path: a float estimate
    corrected downward and then upward. A negative ``num`` raises ValueError
    on both paths.
    """
    if isinstance(num, int):
        return _math_isqrt(num)
    res = int(sqrt(num))
    while res * res > num:
        res -= 1
    while (res + 1) * (res + 1) <= num:
        res += 1
    return res


def int1(s):
    """Parse ``s`` as an int and shift it from 1-based to 0-based."""
    return int(s) - 1


from types import GeneratorType


def bootstrap(f, stack=[]):
    """Run a generator-based recursive function on an explicit stack.

    ``f`` must be a generator function. Each recursive call is ``yield``ed;
    while a run is in progress the wrapper returns the raw generator. The
    function finally ``yield``s its result. The outermost call drives all
    generators from a single loop and returns that result, so recursion depth
    is not limited by Python's call stack.

    The mutable default ``stack`` is intentional. It is shared by every
    decorated function, so nested calls across functions join the same run.
    """

    def wrapped(*args, **kwargs):
        if stack:
            # Already inside a run: hand the generator back to the driver.
            return f(*args, **kwargs)
        else:
            to = f(*args, **kwargs)
            while True:
                if type(to) is GeneratorType:
                    # A new (sub)call: suspend it on the stack and start it.
                    stack.append(to)
                    to = next(to)
                else:
                    # A plain value is the finished call's result: pop it and
                    # resume the caller with that value.
                    stack.pop()
                    if not stack:
                        break
                    to = stack[-1].send(to)
            return to

    return wrapped


import sys
import os
from io import BytesIO, IOBase

BUFSIZE = 8192  # minimum chunk size for os.read


class FastIO(IOBase):
    """Byte-level buffered reader/writer over a raw file descriptor.

    Input is pulled with ``os.read`` into an in-memory ``BytesIO``. Output is
    accumulated in the same kind of buffer and written with ``os.write`` only
    when ``flush`` is called.
    """

    # Number of complete lines buffered but not yet returned by readline().
    newlines = 0

    def __init__(self, file):
        self._fd = file.fileno()
        self.buffer = BytesIO()
        # NOTE: shadows IOBase.writable() with a bool derived from file.mode.
        self.writable = "x" in file.mode or "r" not in file.mode
        self.write = self.buffer.write if self.writable else None

    def read(self):
        """Read everything until EOF and return the unread buffered bytes."""
        while True:
            b = os.read(self._fd, max(os.fstat(self._fd).st_size, BUFSIZE))
            if not b:
                break
            # Append at the end, then restore the read position.
            ptr = self.buffer.tell()
            self.buffer.seek(0, 2), self.buffer.write(b), self.buffer.seek(ptr)
        self.newlines = 0
        return self.buffer.read()

    def readline(self):
        """Return the next line (bytes, including its newline if present)."""
        while self.newlines == 0:
            b = os.read(self._fd, max(os.fstat(self._fd).st_size, BUFSIZE))
            # `(not b)` counts EOF as one line so the loop always terminates.
            self.newlines = b.count(b"\n") + (not b)
            ptr = self.buffer.tell()
            self.buffer.seek(0, 2), self.buffer.write(b), self.buffer.seek(ptr)
        self.newlines -= 1
        return self.buffer.readline()

    def flush(self):
        """Write all buffered output to the descriptor and clear the buffer."""
        if self.writable:
            os.write(self._fd, self.buffer.getvalue())
            self.buffer.truncate(0), self.buffer.seek(0)


class IOWrapper(IOBase):
    """Text adapter over ``FastIO`` that encodes and decodes ASCII."""

    def __init__(self, file):
        self.buffer = FastIO(file)
        self.flush = self.buffer.flush
        self.writable = self.buffer.writable
        self.write = lambda s: self.buffer.write(s.encode("ascii"))
        self.read = lambda: self.buffer.read().decode("ascii")
        self.readline = lambda: self.buffer.readline().decode("ascii")


# Route all stdio through the buffered wrappers. Output reaches fd 1 only when
# the stdout wrapper is flushed.
# NOTE(review): presumably this happens at interpreter shutdown — confirm.
sys.stdin, sys.stdout = IOWrapper(sys.stdin), IOWrapper(sys.stdout)
input = lambda: sys.stdin.readline().rstrip("\r\n")
print = lambda *args, end="\n", sep=" ": sys.stdout.write(
    sep.join(map(str, args)) + end
)


def II():
    """Read one line as a single int."""
    return int(input())


def MII(base=0):
    """Read one line of whitespace-separated ints, each minus ``base`` (lazy)."""
    return map(lambda s: int(s) - base, input().split())


def LII(base=0):
    """Read one line of whitespace-separated ints, each minus ``base``, as a list."""
    return list(MII(base))


def NA():
    """Read a count line and then a list line; return ``(n, list)``."""
    n = II()
    a = LII()
    return n, a


def read_graph(n, m, base=0, directed=False, return_edges=False):
    """Read ``m`` lines ``a b`` into an adjacency list of ``n`` vertices.

    ``base`` is subtracted from both endpoints. Unless ``directed``, each edge
    is added in both directions. With ``return_edges`` the result is
    ``(adjacency, edge_list)``; otherwise only the adjacency list.
    """
    g = [[] for _ in range(n)]
    edges = []
    for _ in range(m):
        a, b = MII(base)
        if return_edges:
            edges.append((a, b))
        g[a].append(b)
        if not directed:
            g[b].append(a)
    if return_edges:
        return g, edges
    return g


def read_graph_with_weight(n, m, base=0, directed=False, return_edges=False):
    """Read ``m`` lines ``a b w`` into an adjacency list of ``(neighbor, w)``.

    ``base`` is subtracted from the endpoints only, never from the weight.
    ``directed`` and ``return_edges`` behave as in ``read_graph``; the edge
    list holds ``(a, b, w)`` triples.
    """
    g = [[] for _ in range(n)]
    edges = []
    for _ in range(m):
        a, b, w = MII()
        a, b = a - base, b - base
        if return_edges:
            edges.append((a, b, w))
        g[a].append((b, w))
        if not directed:
            g[b].append((a, w))
    if return_edges:
        return g, edges
    return g


def read_edges_from_ps():
    """Read one line of 1-based parent ids ``p_1 p_2 ...``.

    Returns ``(p - 1, i)`` edges, with the child index ``i`` starting at 1.
    """
    ps = LII(1)
    edges = []
    for i, p in enumerate(ps, 1):
        edges.append((p, i))
    return edges


def yes(res):
    print("Yes" if res else "No")


def YES(res):
    print("YES" if res else "NO")


def cmin(dp, i, x):
    """Set ``dp[i] = x`` if ``x`` is smaller."""
    if x < dp[i]:
        dp[i] = x


def cmax(dp, i, x):
    """Set ``dp[i] = x`` if ``x`` is larger."""
    if x > dp[i]:
        dp[i] = x


def alp_a_to_i(s):
    return ord(s) - ord("a")


def alp_A_to_i(s):
    return ord(s) - ord("A")


def alp_i_to_a(i):
    return chr(ord("a") + i)


def alp_i_to_A(i):
    return chr(ord("A") + i)


# 4- and 8-neighbourhood offsets for grid traversal.
d4 = [(1, 0), (0, 1), (-1, 0), (0, -1)]
d8 = [(1, 0), (1, 1), (0, 1), (-1, 1), (-1, 0), (-1, -1), (0, -1), (1, -1)]


def ranges(n, m):
    """Lazily yield every ``(i, j)`` with ``0 <= i < n`` and ``0 <= j < m``."""
    return ((i, j) for i in range(n) for j in range(m))


def rangess(a, b, c):
    """Lazily yield every ``(i, j, k)`` in an ``a x b x c`` box."""
    return ((i, j, k) for i in range(a) for j in range(b) for k in range(c))


def valid(i, j, n, m):
    """Return whether ``(i, j)`` lies inside an ``n x m`` grid."""
    return 0 <= i < n and 0 <= j < m


def ninj(i, j, n, m):
    """Return the in-grid 4-neighbours of ``(i, j)`` in an ``n x m`` grid."""
    return [(i + di, j + dj) for di, dj in d4 if valid(i + di, j + dj, n, m)]


def gen(x, *args):
    """Build a nested list with dimensions ``args``, every cell set to ``x``.

    Works for any number of dimensions; the original handled at most four and
    silently returned None beyond that. Inner lists are distinct objects,
    while the innermost level repeats ``x`` by reference. With no dimensions
    it still returns None, as before.
    """
    if not args:
        return None
    if len(args) == 1:
        return [x] * args[0]
    return [gen(x, *args[1:]) for _ in [0] * args[0]]


list2d = lambda a, b, v: [[v] * b for _ in range(a)]
list3d = lambda a, b, c, v: [[[v] * c for _ in range(b)] for _ in range(a)]


class Debug:
    """Switch for the ``icecream`` debug printer.

    The real ``ic`` is used only when ``debug`` is truthy and a ``.cph`` entry
    exists next to this script (presumably a local-judge marker, so judge runs
    get a no-op).
    """

    def __init__(self, debug=False):
        self.debug = debug
        cur_path = os.path.dirname(os.path.abspath(__file__))
        self.local = os.path.exists(cur_path + "/.cph")

    def get_ic(self):
        """Return ``icecream.ic`` when enabled locally, else a no-op callable."""
        if self.debug and self.local:
            from icecream import ic

            return ic
        else:
            return lambda *args, **kwargs: ...
def pairwise(a):
    """Yield consecutive pairs ``(a[0], a[1]), (a[1], a[2]), ...`` of a sequence."""
    n = len(a)
    for i in range(n - 1):
        yield a[i], a[i + 1]


def factorial(n):
    """Return ``n!`` for a non-negative integer ``n``.

    Shadows the ``math.factorial`` imported at the top of the file.
    ``reduce`` is seeded with 1 so that ``factorial(0) == 1``; unseeded,
    ``reduce`` over the empty ``range(1, 1)`` raises TypeError. Negative
    ``n`` raises ValueError, as ``math.factorial`` does.
    """
    if n < 0:
        raise ValueError("factorial() not defined for negative values")
    return reduce(lambda x, y: x * y, range(1, n + 1), 1)


# Debug printer: icecream locally when enabled, otherwise a no-op.
ic = Debug(1).get_ic()
# Rebinds the float `inf` imported from math to the Inf sentinel.
inf = Inf(-1)


class PrefixSum:
    """Prefix sums of a sequence; ``self.sum[i]`` is the sum of ``a[:i]``.

    Indexing differs from Python conventions:
    * ``ps[k]`` is the inclusive prefix sum ``a[0] + ... + a[k]``;
    * ``ps[l:r]`` is the INCLUSIVE range sum ``a[l] + ... + a[r]``, with
      ``l`` defaulting to 0 and ``r`` to ``n - 1``; the step is ignored;
    * iteration yields the ``n + 1`` raw prefix values, while ``len`` is ``n``.
    """

    def __init__(self, a):
        self.n = len(a)
        self.sum = [0] * (self.n + 1)
        for i in range(1, self.n + 1):
            self.sum[i] = self.sum[i - 1] + a[i - 1]

    def __getitem__(self, key):
        if isinstance(key, slice):
            start = key.start if key.start is not None else 0
            stop = key.stop if key.stop is not None else self.n - 1
            return self.get_sum(start, stop)
        return self.sum[key + 1]

    def __iter__(self):
        return iter(self.sum)

    def __len__(self):
        return self.n

    def get_sum(self, l, r):
        """Return ``a[l] + ... + a[r]`` (inclusive), or 0 when ``l > r``."""
        if l > r:
            return 0
        return self.sum[r + 1] - self.sum[l]

    def __repr__(self):
        return str(self.sum)


class SparseTable:
    """Static range query table; ``st[k][j] = func(data[j : j + 2**k])``.

    ``query`` combines two possibly overlapping power-of-two windows, so
    ``func`` must be idempotent (e.g. ``min``/``max``).
    """

    def __init__(self, data: list, func=min):
        self.func = func
        self.st = st = [list(data)]
        i, N = 1, len(st[0])
        while 2 * i <= N + 1:
            qz = st[-1]
            st.append([func(qz[j], qz[j + i]) for j in range(N - 2 * i + 1)])
            i <<= 1

    def query(self, begin: int, end: int):
        """Return ``func`` over ``data[begin .. end]`` (both ends inclusive)."""
        lg = (end - begin + 1).bit_length() - 1
        return self.func(self.st[lg][begin], self.st[lg][end - (1 << lg) + 1])


class EulerTour:
    """Euler tour of an undirected tree with LCA and auxiliary-tree helpers.

    Add the ``vertex_num - 1`` edges with ``add_edge``, call ``build``, then
    use ``use_LCA`` / ``use_AuxiliaryTree``.
    """

    def __init__(self, vertex_num: int):
        self.N = vertex_num
        self.edge: list[tuple[int, int]] = []
        # G[v] holds (neighbor, edge id) pairs.
        self.G: list[list[tuple[int, int]]] = [[] for _ in range(vertex_num)]
        self.lca = None

    def add_edge(self, u: int, v: int):
        """Add undirected edge ``u - v``; edge ids are assigned 0, 1, 2, ..."""
        eid = len(self.edge)
        self.edge.append((u, v))
        self.G[u].append((v, eid))
        self.G[v].append((u, eid))

    def build(self, root=0):
        """Run an iterative DFS from ``root`` and record the Euler tour.

        Sets:
        * ``tour`` / ``depth``: the vertex and its depth, appended on first
          entry and again after returning from each child; a ``-1`` sentinel
          closes both lists;
        * ``node_in[v]``: tour index of v's first entry;
        * ``node_out[v]``: one past the tour index of v's last entry;
        * ``node_depth[v]``: depth of v (``root`` has depth 0);
        * ``edge_in`` / ``edge_out``: tour times at which each edge is
          entered or left;
        * ``parent[v]``: DFS parent, -1 for ``root``.
        """
        N, G = self.N, self.G
        self.root = root
        tour = []
        depth = []
        node_in = [-1] * N
        node_out = [-1] * N
        node_depth = [-1] * N
        edge_in = [-1] * (N - 1)
        edge_out = [-1] * (N - 1)
        parent = [-1] * N
        # Stack entries: (vertex, depth, id of the edge used to reach it).
        stk = [(root, 0, -1)]
        t = -1
        while stk:
            t += 1
            v, d, ei = stk.pop()
            if node_in[v] < 0:
                # First visit of v.
                node_in[v] = t
                node_depth[v] = d
                if ei >= 0:
                    edge_in[ei] = t
                tour.append(v)
                depth.append(d)
                is_leaf = True
                for nv, ne in G[v]:
                    # parent[] doubles as the "already discovered" mark, but
                    # the root keeps parent -1, so it must be skipped
                    # explicitly. Otherwise the first expanded child re-adds
                    # the root as its own child: that records a bogus
                    # (root, depth 2) tour entry and overwrites parent[root].
                    if nv == root or parent[nv] >= 0:
                        continue
                    parent[nv] = v
                    # Push the return-to-v marker first so it pops after
                    # nv's whole subtree.
                    stk.append((v, d, ne))
                    stk.append((nv, d + 1, ne))
                    is_leaf = False
                if is_leaf:
                    node_out[v] = t + 1
                    if ei >= 0:
                        edge_out[ei] = t + 1
            else:
                # Back at v after finishing the child reached via edge ei.
                node_out[v] = t + 1
                edge_out[ei] = t + 1
                tour.append(v)
                depth.append(d)
        tour.append(-1)
        depth.append(-1)
        self.tour = tour
        self.depth = depth
        self.node_in = node_in
        self.node_out = node_out
        self.node_depth = node_depth
        self.edge_in = edge_in
        self.edge_out = edge_out
        self.parent = parent

    class __LCA:
        """LCA via range-min over ``(depth, vertex)`` pairs of the Euler tour."""

        def __init__(self, tour: list[int], tour_depth: list[int], node_in: list[int]):
            data = [(d, v) for d, v in zip(tour_depth, tour)]
            self._st = SparseTable(data, func=lambda x, y: x if x <= y else y)
            self._node_in = node_in

        def get(self, u, v):
            """Return the shallowest tour vertex between u's and v's first entries."""
            node_in = self._node_in
            l, r = node_in[u], node_in[v]
            if l > r:
                l, r = r, l
            return self._st.query(l, r)[1]

    def use_LCA(self):
        """Return the (lazily built, cached) LCA helper; requires ``build()``."""
        if self.lca is None:
            self.lca = self.__LCA(self.tour, self.depth, self.node_in)
        return self.lca

    class __AuxiliaryTree(dict):
        """Maps group key -> ``(G, P, vv)`` for that group's auxiliary tree.

        ``vv`` holds the group's vertices plus the LCAs of node_in-adjacent
        pairs, sorted by ``node_in``. ``G`` maps an aux-tree parent to its
        children list, and ``P`` maps a child to its aux-tree parent. The
        ``parent`` argument is accepted but unused.
        """

        def __init__(
            self,
            vertex_group_id: list[int],
            special_nodes: list[int],
            node_in: list[int],
            node_out: list[int],
            lca,
            parent,
        ):
            V: dict[int, list[int]] = dict()
            # Exactly one of the two selectors must be given.
            if not ((vertex_group_id is None) ^ (special_nodes is None)):
                raise ValueError
            if vertex_group_id is not None:
                for v, g in enumerate(vertex_group_id):
                    if g not in V:
                        V[g] = []
                    V[g].append(v)
            if special_nodes is not None:
                # Copy so the caller's list is not reordered.
                V[1] = special_nodes[::]
            for k, vv in V.items():
                vv.sort(key=lambda v: node_in[v])
                for i in range(1, len(vv)):
                    vv.append(lca.get(vv[i - 1], vv[i]))
                vv = sorted(set(vv), key=lambda v: node_in[v])
                G: dict[int, list[int]] = dict()
                P: dict[int, int] = dict()
                stk: list[int] = []
                for v in vv:
                    # Pop stack entries that are not ancestors of v.
                    while stk and node_out[stk[-1]] <= node_out[v]:
                        stk.pop()
                    if stk:
                        p = stk[-1]
                        if p not in G:
                            G[p] = []
                        G[p].append(v)
                        P[v] = p
                    stk.append(v)
                self[k] = (G, P, vv)

    def use_AuxiliaryTree(
        self, vertex_group_id: list[int] = None, special_nodes: list[int] = None
    ) -> dict[int, tuple[dict[int, list[int]], dict[int, int], list[int]]]:
        """Build auxiliary trees; ``special_nodes`` are stored under key 1.

        Raises ValueError unless exactly one argument is given.
        """
        return self.__AuxiliaryTree(
            vertex_group_id,
            special_nodes,
            self.node_in,
            self.node_out,
            self.use_LCA(),
            self.parent,
        )


class Tree:
    """Rooted tree from an adjacency list ``g`` or an edge list ``edges``.

    Exposes ``parent`` (-1 for the root), ``depth``, and ``order`` (vertices
    in discovery order, with every parent before its children). ``vals`` is
    accepted for backward compatibility and is unused.
    """

    def __init__(self, g=None, edges=None, root=0, vals=None):
        if edges is not None:
            self.n = n = len(edges) + 1
            self.g = g = [[] for _ in range(n)]
            for u, v in edges:
                self.g[u].append(v)
                self.g[v].append(u)
        else:
            self.n = n = len(g)
            self.g = g
        self.root = root
        self.parent = parent = [-1] * n
        stk = [root]
        self.order = order = [root]
        self.depth = depth = [0] * n
        while stk:
            u = stk.pop()
            for v in g[u]:
                if v != root and parent[v] == -1:
                    depth[v] = depth[u] + 1
                    parent[v] = u
                    stk.append(v)
                    order.append(v)


# ---- Solve: for each query, total weight of the minimal subtree connecting
# the given vertices (sum of auxiliary-tree edge lengths). ----
n = II()
T = EulerTour(n)
g, edges = read_graph_with_weight(n, n - 1, 0, return_edges=True)
for u, v, w in edges:
    T.add_edge(u, v)
T.build()
lca = T.use_LCA()

# Root the tree at vertex 0 (same root as T.build()) to orient each edge.
t = Tree(edges=[(u, v) for u, v, _ in edges])
order, depth, parent = t.order, t.depth, t.parent

# weight_up[c] = weight of the edge from c to its parent (0 for the root).
weight_up = [0] * n
for u, v, w in edges:
    child = u if depth[u] > depth[v] else v
    weight_up[child] = w

# b[u] = weighted distance from the root to u; parents precede children in order.
b = weight_up[::]
for u in order[1:]:
    b[u] += b[parent[u]]


def query(u, v):
    """Return the weighted distance between vertices ``u`` and ``v``."""
    return b[u] + b[v] - 2 * b[lca.get(u, v)]


for _ in range(II()):
    # Query line: a count followed by the vertices; the count is discarded.
    nodes = LII()[1:]
    res = 0
    for children, _parents, tour in T.use_AuxiliaryTree(special_nodes=nodes).values():
        # Each aux-tree edge u -> v stands for the tree path u..v.
        for u in tour[::-1]:
            for v in children.get(u, []):
                res += query(u, v)
    print(res)