結果
問題 | No.901 K-ary εxtrεεmε |
ユーザー | StanMarsh |
提出日時 | 2024-02-27 00:18:53 |
言語 | PyPy3 (7.3.15) |
結果 | AC |
実行時間 | 1,065 ms / 3,000 ms |
コード長 | 15,219 bytes |
コンパイル時間 | 365 ms |
コンパイル使用メモリ | 82,160 KB |
実行使用メモリ | 236,260 KB |
最終ジャッジ日時 | 2024-09-29 11:51:10 |
合計ジャッジ時間 | 27,991 ms |
ジャッジサーバーID (参考情報) | judge5 / judge2 |
(要ログイン)
テストケース
入力 | 結果 | 実行時間 | 実行使用メモリ |
---|---|---|---|
testcase_00 | AC | 708 ms | 236,260 KB |
testcase_01 | AC | 153 ms | 92,152 KB |
testcase_02 | AC | 214 ms | 93,148 KB |
testcase_03 | AC | 207 ms | 93,588 KB |
testcase_04 | AC | 209 ms | 93,276 KB |
testcase_05 | AC | 211 ms | 93,124 KB |
testcase_06 | AC | 212 ms | 93,368 KB |
testcase_07 | AC | 1,045 ms | 210,560 KB |
testcase_08 | AC | 1,012 ms | 209,140 KB |
testcase_09 | AC | 994 ms | 209,268 KB |
testcase_10 | AC | 1,025 ms | 211,160 KB |
testcase_11 | AC | 1,007 ms | 209,628 KB |
testcase_12 | AC | 1,035 ms | 210,780 KB |
testcase_13 | AC | 1,050 ms | 211,576 KB |
testcase_14 | AC | 1,018 ms | 211,936 KB |
testcase_15 | AC | 1,014 ms | 211,032 KB |
testcase_16 | AC | 1,047 ms | 209,092 KB |
testcase_17 | AC | 1,053 ms | 210,748 KB |
testcase_18 | AC | 1,029 ms | 211,900 KB |
testcase_19 | AC | 1,027 ms | 211,136 KB |
testcase_20 | AC | 1,032 ms | 210,868 KB |
testcase_21 | AC | 1,065 ms | 211,492 KB |
testcase_22 | AC | 1,061 ms | 220,508 KB |
testcase_23 | AC | 1,029 ms | 220,992 KB |
testcase_24 | AC | 1,037 ms | 222,172 KB |
testcase_25 | AC | 1,032 ms | 221,088 KB |
testcase_26 | AC | 1,029 ms | 220,492 KB |
testcase_27 | AC | 936 ms | 210,716 KB |
testcase_28 | AC | 923 ms | 212,416 KB |
testcase_29 | AC | 913 ms | 211,348 KB |
ソースコード
# Competitive-programming template: stdlib imports, small numeric helpers,
# a buffered stdin/stdout replacement, input-parsing helpers and grid helpers.
from random import getrandbits, randrange
from string import ascii_lowercase, ascii_uppercase
import sys
from math import ceil, floor, sqrt, pi, factorial, gcd, log, log10, log2, inf, cos, sin
from copy import deepcopy, copy
from collections import Counter, deque, defaultdict
from heapq import heapify, heappop, heappush
from itertools import (
    accumulate,
    chain,
    product,
    combinations,
    combinations_with_replacement,
    permutations,
)
from bisect import bisect, bisect_left, bisect_right
from functools import lru_cache, reduce
from decimal import Decimal, getcontext
from typing import List, Tuple, Optional


class Inf:
    """Sentinel that orders above every non-``Inf`` value.

    ``<`` is always False, ``>=`` is always True, ``>`` is True against any
    non-``Inf`` operand, and ``<=`` is True only against another ``Inf``.
    ``+``, ``-`` and ``*`` with a non-``Inf`` operand return ``self``
    unchanged; with an ``Inf`` operand they return a new ``Inf`` carrying
    this instance's ``value``. Because ``__eq__`` is defined without
    ``__hash__``, instances are unhashable.
    """

    def __init__(self, value):
        self.value = value

    def __lt__(self, other):
        return False

    def __le__(self, other):
        if isinstance(other, Inf):
            return True
        return False

    def __gt__(self, other):
        if isinstance(other, Inf):
            return False
        return True

    def __ge__(self, other):
        return True

    def __eq__(self, other):
        # Equal only to another Inf carrying the same payload value.
        return isinstance(other, Inf) and self.value == other.value

    def __repr__(self):
        return f"{self.value}"

    def __add__(self, other):
        return Inf(self.value) if isinstance(other, Inf) else self

    def __sub__(self, other):
        return Inf(self.value) if isinstance(other, Inf) else self

    def __mul__(self, other):
        return Inf(self.value) if isinstance(other, Inf) else self


def ceil_div(a, b):
    """Return ``ceil(a / b)`` using integer arithmetic (intended for b > 0)."""
    return (a + b - 1) // b


def isqrt(num):
    """Return the floor of the square root of the non-negative integer *num*.

    Starts from the float estimate ``int(sqrt(num))`` and then corrects it in
    both directions. Because it goes through ``math.sqrt``, *num* must be
    convertible to float (negative input raises ``ValueError``).
    """
    res = int(sqrt(num))
    # Fix float rounding: step down while too big, then up while too small.
    while res * res > num:
        res -= 1
    while (res + 1) * (res + 1) <= num:
        res += 1
    return res


def int1(s):
    """Parse *s* as an integer and shift it from 1-based to 0-based."""
    return int(s) - 1


from types import GeneratorType


def bootstrap(f, stack=[]):
    """Run a generator-style recursive function without deep Python recursion.

    The decorated function ``yield``s each recursive call (which, while the
    driver is active, just returns a generator) and finally ``yield``s its
    result. The outermost call drives all pending generators with an explicit
    stack, sending each finished result back into its caller.

    The mutable default ``stack`` is intentionally shared by every function
    decorated with ``bootstrap``: a non-empty stack means a driver is already
    running, so nested calls return the raw generator instead of starting a
    second driver.
    """

    def wrapped(*args, **kwargs):
        if stack:
            # Inside an active driver: hand the generator back to it.
            return f(*args, **kwargs)
        else:
            to = f(*args, **kwargs)
            while True:
                if type(to) is GeneratorType:
                    # A recursive call: push it and run it to its first yield.
                    stack.append(to)
                    to = next(to)
                else:
                    # A plain value: the top frame finished; resume its caller.
                    stack.pop()
                    if not stack:
                        break
                    to = stack[-1].send(to)
            return to

    return wrapped


import sys
import os
from io import BytesIO, IOBase

BUFSIZE = 8192  # minimum chunk size requested from os.read


class FastIO(IOBase):
    """Byte-level buffered I/O on a raw file descriptor.

    Input is pulled with ``os.read`` in chunks of at least ``BUFSIZE`` (or
    the file size reported by ``fstat`` when larger) and appended to an
    in-memory ``BytesIO``. Output is accumulated in the same buffer and only
    written out by ``flush()``.
    """

    newlines = 0  # number of complete lines buffered but not yet consumed

    def __init__(self, file):
        self._fd = file.fileno()
        self.buffer = BytesIO()
        # Writable when opened for exclusive creation ("x") or not for reading.
        self.writable = "x" in file.mode or "r" not in file.mode
        self.write = self.buffer.write if self.writable else None

    def read(self):
        """Read until EOF and return all remaining buffered bytes."""
        while True:
            b = os.read(self._fd, max(os.fstat(self._fd).st_size, BUFSIZE))
            if not b:
                break
            # Append at the end without losing the current read position.
            ptr = self.buffer.tell()
            self.buffer.seek(0, 2), self.buffer.write(b), self.buffer.seek(ptr)
        self.newlines = 0
        return self.buffer.read()

    def readline(self):
        """Return the next line (bytes), reading more input only when needed."""
        while self.newlines == 0:
            b = os.read(self._fd, max(os.fstat(self._fd).st_size, BUFSIZE))
            # An empty read (EOF) counts as one "line" so the loop terminates.
            self.newlines = b.count(b"\n") + (not b)
            ptr = self.buffer.tell()
            self.buffer.seek(0, 2), self.buffer.write(b), self.buffer.seek(ptr)
        self.newlines -= 1
        return self.buffer.readline()

    def flush(self):
        """Write all buffered output to the descriptor and clear the buffer."""
        if self.writable:
            os.write(self._fd, self.buffer.getvalue())
            self.buffer.truncate(0), self.buffer.seek(0)


class IOWrapper(IOBase):
    """Text (ASCII) facade over ``FastIO`` for use as sys.stdin / sys.stdout."""

    def __init__(self, file):
        self.buffer = FastIO(file)
        self.flush = self.buffer.flush
        self.writable = self.buffer.writable
        self.write = lambda s: self.buffer.write(s.encode("ascii"))
        self.read = lambda: self.buffer.read().decode("ascii")
        self.readline = lambda: self.buffer.readline().decode("ascii")


# Replace the standard streams with the buffered wrappers.
# NOTE(review): nothing here calls sys.stdout.flush() explicitly; output
# presumably reaches the fd when IOBase finalization calls close() -> flush().
sys.stdin, sys.stdout = IOWrapper(sys.stdin), IOWrapper(sys.stdout)
# input(): one line from the buffered stdin without the trailing newline.
input = lambda: sys.stdin.readline().rstrip("\r\n")
# print(): simplified builtin replacement (supports only `end` and `sep`).
print = lambda *args, end="\n", sep=" ": sys.stdout.write(
    sep.join(map(str, args)) + end
)


def II():
    """Read one line and parse it as a single int."""
    return int(input())


def MII(base=0):
    """Read one line of ints, subtracting *base* from each (lazy map)."""
    return map(lambda s: int(s) - base, input().split())


def LII(base=0):
    """Read one line of ints, subtracting *base* from each, as a list."""
    return list(MII(base))


def NA():
    """Read a count line followed by a line of ints; return ``(n, list)``."""
    n = II()
    a = LII()
    return n, a


def read_graph(n, m, base=0, directed=False, return_edges=False):
    """Read *m* unweighted edges ``a b`` into an adjacency list of *n* vertices.

    Endpoints are shifted by *base*. Undirected unless *directed*. Returns
    ``g``, or ``(g, edges)`` when *return_edges* is true.
    """
    g = [[] for _ in range(n)]
    edges = []
    for _ in range(m):
        a, b = MII(base)
        if return_edges:
            edges.append((a, b))
        g[a].append(b)
        if not directed:
            g[b].append(a)
    if return_edges:
        return g, edges
    return g


def read_graph_with_weight(n, m, base=0, directed=False, return_edges=False):
    """Read *m* weighted edges ``a b w`` into an adjacency list of *n* vertices.

    Only the endpoints are shifted by *base*; the weight is kept as read.
    Adjacency entries are ``(neighbor, w)``. Returns ``g``, or
    ``(g, edges)`` with ``edges`` as ``(a, b, w)`` when *return_edges* is true.
    """
    g = [[] for _ in range(n)]
    edges = []
    for _ in range(m):
        a, b, w = MII()
        a, b = a - base, b - base
        if return_edges:
            edges.append((a, b, w))
        g[a].append((b, w))
        if not directed:
            g[b].append((a, w))
    if return_edges:
        return g, edges
    return g


def read_edges_from_ps():
    """Read a 1-based parent list for vertices 1, 2, ... and return ``(p, i)`` edges.

    Parents are shifted to 0-based; the child index ``i`` starts at 1.
    """
    ps = LII(1)
    edges = []
    for i, p in enumerate(ps, 1):
        edges.append((p, i))
    return edges


def yes(res):
    """Print "Yes" or "No" depending on the truthiness of *res*."""
    print("Yes" if res else "No")


def YES(res):
    """Print "YES" or "NO" depending on the truthiness of *res*."""
    print("YES" if res else "NO")


def cmin(dp, i, x):
    """Set ``dp[i] = x`` if *x* is smaller (in-place chmin)."""
    if x < dp[i]:
        dp[i] = x


def cmax(dp, i, x):
    """Set ``dp[i] = x`` if *x* is larger (in-place chmax)."""
    if x > dp[i]:
        dp[i] = x


def alp_a_to_i(s):
    """Map a lowercase letter to its 0-based alphabet index."""
    return ord(s) - ord("a")


def alp_A_to_i(s):
    """Map an uppercase letter to its 0-based alphabet index."""
    return ord(s) - ord("A")


def alp_i_to_a(i):
    """Map a 0-based alphabet index to a lowercase letter."""
    return chr(ord("a") + i)


def alp_i_to_A(i):
    """Map a 0-based alphabet index to an uppercase letter."""
    return chr(ord("A") + i)


# 4- and 8-neighbourhood offsets for grid traversal.
d4 = [(1, 0), (0, 1), (-1, 0), (0, -1)]
d8 = [(1, 0), (1, 1), (0, 1), (-1, 1), (-1, 0), (-1, -1), (0, -1), (1, -1)]


def ranges(n, m):
    """Lazily yield every ``(i, j)`` with ``0 <= i < n`` and ``0 <= j < m``."""
    return ((i, j) for i in range(n) for j in range(m))


def rangess(a, b, c):
    """Lazily yield every ``(i, j, k)`` in the ``a x b x c`` index box."""
    return ((i, j, k) for i in range(a) for j in range(b) for k in range(c))


def valid(i, j, n, m):
    """Return whether ``(i, j)`` lies inside an ``n x m`` grid."""
    return 0 <= i < n and 0 <= j < m


def ninj(i, j, n, m):
    """Return the in-bounds 4-neighbours of ``(i, j)`` in an ``n x m`` grid."""
    return [(i + di, j + dj) for di, dj in d4 if valid(i + di, j + dj, n, m)]


def gen(x, *args):
    """Build a nested list of shape *args* (1 to 4 dims) filled with *x*.

    Any other number of dimensions falls through and returns ``None``.
    """
    if len(args) == 1:
        return [x] * args[0]
    if len(args) == 2:
        return [[x] * args[1] for _ in [0] * args[0]]
    if len(args) == 3:
        return [[[x] * args[2] for _ in [0] * args[1]] for _ in [0] * args[0]]
    if len(args) == 4:
        return [
            [[[x] * args[3] for _ in [0] * args[2]] for _ in [0] * args[1]]
            for _ in [0] * args[0]
        ]


# Shortcut constructors for a x b (and a x b x c) lists filled with v.
list2d = lambda a, b, v: [[v] * b for _ in range(a)]
list3d = lambda a, b, c, v: [[[v] * c for _ in range(b)] for _ in range(a)]


class Debug:
    """Provide an ``ic`` debug printer only when running locally.

    "Local" means a ``.cph`` entry exists next to this script (presumably
    created by the cph editor extension - TODO confirm).
    """

    def __init__(self, debug=False):
        self.debug = debug
        cur_path = os.path.dirname(os.path.abspath(__file__))
        self.local = os.path.exists(cur_path + "/.cph")

    def get_ic(self):
        """Return ``icecream.ic`` when debugging locally, else a no-op callable."""
        if self.debug and self.local:
            from icecream import ic

            return ic
        else:
            return lambda *args, **kwargs: ...
def pairwise(a):
    """Yield each adjacent pair ``(a[i], a[i + 1])`` of the indexable sequence *a*."""
    n = len(a)
    for i in range(n - 1):
        yield a[i], a[i + 1]


def factorial(n):
    """Return ``n!`` for a non-negative integer *n*.

    The initializer ``1`` makes ``factorial(0) == 1``. Without it, ``reduce``
    raises ``TypeError`` on the empty range.

    Raises:
        ValueError: if *n* is negative (same contract as ``math.factorial``).
    """
    if n < 0:
        raise ValueError("factorial() not defined for negative values")
    return reduce(lambda x, y: x * y, range(1, n + 1), 1)


# `ic` is icecream's printer when debugging locally, otherwise a no-op.
ic = Debug(1).get_ic()
# Module-level "infinity" sentinel (rebinds the `inf` imported from math).
inf = Inf(-1)


class PrefixSum:
    """Prefix sums over the sequence *a*.

    Indexing uses INCLUSIVE ranges, not Python's half-open convention:
    ``ps[i]`` is ``a[0] + ... + a[i]``, and ``ps[l:r]`` is ``a[l] + ... + a[r]``
    (a missing start means 0, a missing stop means ``n - 1``). Iterating yields
    the ``n + 1`` raw prefix sums starting with 0; ``len()`` is ``n``.
    """

    def __init__(self, a):
        self.n = len(a)
        self.sum = [0] * (self.n + 1)  # sum[i] = a[0] + ... + a[i - 1]
        for i in range(1, self.n + 1):
            self.sum[i] = self.sum[i - 1] + a[i - 1]

    def __getitem__(self, key):
        if isinstance(key, slice):
            start = key.start if key.start is not None else 0
            stop = key.stop if key.stop is not None else self.n - 1
            return self.get_sum(start, stop)
        return self.sum[key + 1]

    def __iter__(self):
        return iter(self.sum)

    def __len__(self):
        return self.n

    def get_sum(self, l, r):
        """Return ``a[l] + ... + a[r]`` (inclusive), or 0 when ``l > r``."""
        if l > r:
            return 0
        return self.sum[r + 1] - self.sum[l]

    def __repr__(self):
        return str(self.sum)


class SparseTable:
    """Static range-query table answering ``func`` over inclusive ranges.

    ``query`` combines two possibly overlapping power-of-two windows, so
    *func* must be idempotent (``func(x, x) == x``, e.g. ``min`` or ``max``).
    """

    def __init__(self, data: list, func=min):
        self.func = func
        # st[k][j] = func over data[j : j + 2**k]
        self.st = st = [list(data)]
        i, N = 1, len(st[0])
        while 2 * i <= N + 1:
            qz = st[-1]
            st.append([func(qz[j], qz[j + i]) for j in range(N - 2 * i + 1)])
            i <<= 1

    def query(self, begin: int, end: int):
        """Return ``func`` over ``data[begin..end]`` (inclusive, begin <= end)."""
        lg = (end - begin + 1).bit_length() - 1
        return self.func(self.st[lg][begin], self.st[lg][end - (1 << lg) + 1])


class EulerTour:
    """Euler tour of a tree with O(1) LCA queries and auxiliary-tree building.

    Usage: ``add_edge`` for every edge, then ``build(root)``, then
    ``use_LCA()`` / ``use_AuxiliaryTree(...)``. Assumes the added edges form a
    tree on ``vertex_num`` vertices (edge arrays are sized ``N - 1``).
    """

    def __init__(self, vertex_num: int):
        self.N = vertex_num
        self.edge: list[tuple[int, int]] = []
        # G[u] holds (neighbor, edge_id) pairs.
        self.G: list[list[tuple[int, int]]] = [[] for _ in range(vertex_num)]
        self.lca = None  # lazily created by use_LCA()

    def add_edge(self, u: int, v: int):
        """Add the undirected edge ``u - v``; its id is its insertion index."""
        eid = len(self.edge)
        self.edge.append((u, v))
        self.G[u].append((v, eid))
        self.G[v].append((u, eid))

    def build(self, root=0):
        """Run an iterative DFS from *root* and record the Euler tour.

        Every stack pop advances the timer ``t`` and appends exactly one entry
        to ``tour``/``depth``, so ``tour[t]`` is the vertex handled at time t.
        A vertex is appended on its first visit and again each time the DFS
        returns to it from a child. A ``-1`` sentinel ends both lists.

        Sets ``tour``, ``depth``, ``node_in`` (time of first visit),
        ``node_out`` (one past the time of the last visit), ``node_depth``,
        ``edge_in``/``edge_out`` (indexed by edge id), ``parent`` (-1 for the
        root) and ``root``.
        """
        N, G = self.N, self.G
        self.root = root
        tour = []
        depth = []
        node_in = [-1] * N
        node_out = [-1] * N
        node_depth = [-1] * N
        edge_in = [-1] * (N - 1)
        edge_out = [-1] * (N - 1)
        parent = [-1] * N
        # Stack entries are (vertex, depth, id of the edge used to reach it).
        stk = [(root, 0, -1)]
        t = -1
        while stk:
            t += 1
            v, d, ei = stk.pop()
            if node_in[v] < 0:
                # First visit of v.
                node_in[v] = t
                node_depth[v] = d
                if ei >= 0:
                    edge_in[ei] = t
                tour.append(v)
                depth.append(d)
                is_leaf = True
                for nv, ne in G[v]:
                    # Skip v's parent. The root needs an explicit check because
                    # its parent stays -1: otherwise the first child processed
                    # would adopt the root as a child, overwrite parent[root]
                    # and push a bogus (root, d + 1) tour entry.
                    if nv == root or parent[nv] >= 0:
                        continue
                    parent[nv] = v
                    # Push the "return to v" marker below the child itself.
                    stk.append((v, d, ne))
                    stk.append((nv, d + 1, ne))
                    is_leaf = False
                if is_leaf:
                    node_out[v] = t + 1
                    if ei >= 0:
                        edge_out[ei] = t + 1
            else:
                # Returning to v after finishing the child reached via edge ei.
                node_out[v] = t + 1
                edge_out[ei] = t + 1
                tour.append(v)
                depth.append(d)
        tour.append(-1)
        depth.append(-1)
        self.tour = tour
        self.depth = depth
        self.node_in = node_in
        self.node_out = node_out
        self.node_depth = node_depth
        self.edge_in = edge_in
        self.edge_out = edge_out
        self.parent = parent

    class __LCA:
        """LCA oracle: min-depth vertex on the tour between two first visits."""

        def __init__(self, tour: list[int], tour_depth: list[int], node_in: list[int]):
            # Tuples compare by depth first, so the minimum is the shallowest vertex.
            data = [(d, v) for d, v in zip(tour_depth, tour)]
            self._st = SparseTable(data, func=lambda x, y: x if x <= y else y)
            self._node_in = node_in

        def get(self, u, v):
            """Return the lowest common ancestor of *u* and *v*."""
            node_in = self._node_in
            l, r = node_in[u], node_in[v]
            if l > r:
                l, r = r, l
            return self._st.query(l, r)[1]

    def use_LCA(self):
        """Return the (cached) LCA oracle; requires ``build()`` first."""
        if self.lca is None:
            self.lca = self.__LCA(self.tour, self.depth, self.node_in)
        return self.lca

    class __AuxiliaryTree(dict):
        """Map ``group -> (G, P, vv)`` describing each group's auxiliary tree.

        ``vv`` lists the group's vertices plus the LCAs of DFS-adjacent pairs,
        deduplicated and sorted by ``node_in``. ``G`` maps a vertex to its
        auxiliary-tree children, and ``P`` maps a vertex to its auxiliary-tree
        parent (each parent is an ancestor in the original tree). With
        ``special_nodes`` the single group key is ``1``. Exactly one of
        ``vertex_group_id`` / ``special_nodes`` must be given, otherwise
        ``ValueError`` is raised. ``parent`` is accepted but not used here.
        """

        def __init__(
            self,
            vertex_group_id: list[int],
            special_nodes: list[int],
            node_in: list[int],
            node_out: list[int],
            lca,
            parent,
        ):
            V: dict[int, list[int]] = dict()
            if not ((vertex_group_id is None) ^ (special_nodes is None)):
                raise ValueError
            if vertex_group_id is not None:
                for v, g in enumerate(vertex_group_id):
                    if g not in V:
                        V[g] = []
                    V[g].append(v)
            if special_nodes is not None:
                V[1] = special_nodes[::]
            for k, vv in V.items():
                vv.sort(key=lambda v: node_in[v])
                # range() is fixed before the loop, so appended LCAs are not revisited.
                for i in range(1, len(vv)):
                    vv.append(lca.get(vv[i - 1], vv[i]))
                vv = sorted(set(vv), key=lambda v: node_in[v])
                G: dict[int, list[int]] = dict()
                P: dict[int, int] = dict()
                stk: list[int] = []
                for v in vv:
                    # Pop vertices whose subtree ends before v: not ancestors of v.
                    while stk and node_out[stk[-1]] <= node_out[v]:
                        stk.pop()
                    if stk:
                        p = stk[-1]
                        if p not in G:
                            G[p] = []
                        G[p].append(v)
                        P[v] = p
                    stk.append(v)
                self[k] = (G, P, vv)

    def use_AuxiliaryTree(
        self, vertex_group_id: list[int] = None, special_nodes: list[int] = None
    ) -> dict[int, tuple[dict[int, list[int]], dict[int, int], list[int]]]:
        """Build auxiliary trees for vertex groups (see ``__AuxiliaryTree``)."""
        return self.__AuxiliaryTree(
            vertex_group_id,
            special_nodes,
            self.node_in,
            self.node_out,
            self.use_LCA(),
            self.parent,
        )


def MyEulerTour(n, G, i0=0):
    """Iterative Euler tour of a tree given as weighted adjacency ``G[u] = [(v, w), ...]``.

    Returns ``(ET, nodein, nodeout, depth)``. ``ET`` lists each vertex on
    entry and its parent (``-1`` for *i0*) on exit. ``nodein[v]`` is v's
    preorder index, and ``nodeout[v]`` is the last preorder index inside v's
    subtree.
    """
    P = [-1] * n
    # A negative entry ~v marks "leave v"; pushed beneath v itself.
    stack = [~i0, i0]
    ct = -1
    depth = [-1] * n
    nodein = [-1] * n
    nodeout = [-1] * n
    ET = []
    de = 0
    while stack:
        i = stack.pop()
        if i < 0:
            ET.append(P[~i])
            nodeout[~i] = ct
            de -= 1
            continue
        if i >= 0:
            ct += 1
            ET.append(i)
            if nodein[i] == -1:
                nodein[i] = ct
            depth[i] = de
            de += 1
            for v, _ in G[i][::-1]:
                if v == P[i]:
                    continue
                P[v] = i
                stack.append(~v)
                stack.append(v)
    return ET, nodein, nodeout, depth


class Tree:
    """Rooted tree with ``parent``, ``depth`` and a parent-before-child ``order``.

    Built from an unweighted adjacency list *g* or, if given, an edge list
    *edges* (then ``n = len(edges) + 1``). *vals* is accepted for
    compatibility but not used.
    """

    def __init__(self, g=None, edges=None, root=0, vals=None):
        if edges is not None:
            self.n = n = len(edges) + 1
            self.g = g = [[] for _ in range(n)]
            for u, v in edges:
                self.g[u].append(v)
                self.g[v].append(u)
        else:
            self.n = n = len(g)
            self.g = g
        self.root = root
        self.parent = parent = [-1] * n
        stk = [root]
        self.order = order = [root]
        self.depth = depth = [0] * n
        while stk:
            u = stk.pop()
            for v in g[u]:
                # parent[v] == -1 means unvisited; the root is excluded explicitly.
                if v != root and parent[v] == -1:
                    depth[v] = depth[u] + 1
                    parent[v] = u
                    stk.append(v)
                    order.append(v)


# ---- Driver ---------------------------------------------------------------
# Input: n, then n - 1 weighted edges "u v w" (0-based), then q queries, each
# "k x_1 ... x_k". For each query, print the summed weighted length of every
# auxiliary-tree edge over the query vertices, i.e. the total edge weight of
# the subtree spanning them.
n = II()
T = EulerTour(n)
g, edges = read_graph_with_weight(n, n - 1, 0, return_edges=True)
for u, v, w in edges:
    T.add_edge(u, v)
T.build()
lca = T.use_LCA()
t = Tree(edges=[(u, v) for u, v, _ in edges])
depth, parent = t.depth, t.parent
# a[child] = weight of the edge from child to its parent (root keeps 0).
a = [0] * n
for u, v, w in edges:
    child = u if depth[u] > depth[v] else v
    a[child] = w
# b[u] = weighted distance from the root to u. T.tour visits every parent
# before its children, so b[parent] is final when b[u] is accumulated.
b = a[::]
vst = [0] * n
for u in T.tour:
    if u < 0 or vst[u] == 1:
        continue
    vst[u] = 1
    p = parent[u]
    if p != -1:
        b[u] += b[p]


def query(u, v):
    """Return the weighted tree distance between *u* and *v*."""
    return b[u] + b[v] - 2 * b[lca.get(u, v)]


for _ in range(II()):
    a = LII()[1:]  # drop the leading count k
    res = 0
    for ai, (children, ps, tour) in T.use_AuxiliaryTree(special_nodes=a).items():
        for u in tour[::-1]:
            for v in children.get(u, []):
                res += query(u, v)
    print(res)