結果
問題 | No.922 東北きりきざむたん |
ユーザー | terasa |
提出日時 | 2022-11-03 23:31:35 |
言語 | PyPy3 (7.3.15) |
結果 | AC |
実行時間 | 1,205 ms / 2,000 ms |
コード長 | 6,562 bytes |
コンパイル時間 | 378 ms |
コンパイル使用メモリ | 87,044 KB |
実行使用メモリ | 141,376 KB |
最終ジャッジ日時 | 2023-09-25 06:33:44 |
合計ジャッジ時間 | 21,381 ms |
ジャッジサーバーID (参考情報) | judge14 / judge11 |
(要ログイン)
テストケース
テストケース表示入力 | 結果 | 実行時間 | 実行使用メモリ |
---|---|---|---|
testcase_00 | AC | 168 ms | 80,176 KB |
testcase_01 | AC | 163 ms | 80,332 KB |
testcase_02 | AC | 162 ms | 80,140 KB |
testcase_03 | AC | 163 ms | 80,176 KB |
testcase_04 | AC | 191 ms | 82,044 KB |
testcase_05 | AC | 169 ms | 81,916 KB |
testcase_06 | AC | 190 ms | 82,700 KB |
testcase_07 | AC | 177 ms | 81,976 KB |
testcase_08 | AC | 185 ms | 82,280 KB |
testcase_09 | AC | 661 ms | 114,516 KB |
testcase_10 | AC | 577 ms | 96,408 KB |
testcase_11 | AC | 601 ms | 108,468 KB |
testcase_12 | AC | 438 ms | 118,428 KB |
testcase_13 | AC | 362 ms | 92,704 KB |
testcase_14 | AC | 723 ms | 132,428 KB |
testcase_15 | AC | 324 ms | 119,456 KB |
testcase_16 | AC | 1,043 ms | 137,668 KB |
testcase_17 | AC | 1,046 ms | 137,536 KB |
testcase_18 | AC | 1,122 ms | 137,456 KB |
testcase_19 | AC | 1,067 ms | 136,632 KB |
testcase_20 | AC | 1,071 ms | 137,168 KB |
testcase_21 | AC | 1,028 ms | 135,788 KB |
testcase_22 | AC | 1,067 ms | 135,512 KB |
testcase_23 | AC | 1,195 ms | 138,416 KB |
testcase_24 | AC | 1,205 ms | 141,348 KB |
testcase_25 | AC | 929 ms | 139,212 KB |
testcase_26 | AC | 899 ms | 140,304 KB |
testcase_27 | AC | 921 ms | 138,808 KB |
testcase_28 | AC | 352 ms | 126,160 KB |
testcase_29 | AC | 882 ms | 141,376 KB |
ソースコード
from typing import List, Tuple, Callable, TypeVar, Optional
import sys
import itertools
import heapq
import bisect
from collections import deque, defaultdict
from functools import lru_cache, cmp_to_key

input = sys.stdin.readline

# for AtCoder Easy test
if __file__ != 'prog.py':
    sys.setrecursionlimit(10 ** 6)


def readints():
    """Read one line from stdin and return an iterator of its ints."""
    return map(int, input().split())


def readlist():
    """Read one line from stdin and return its ints as a list."""
    return list(readints())


def readstr():
    """Read one line from stdin with trailing whitespace stripped."""
    return input().rstrip()


T = TypeVar('T')


class Rerooting:
    """Generic rerooting DP over a forest given as an adjacency list.

    reference: https://null-mn.hatenablog.com/entry/2020/04/14/124151

    Assumes the value dp_v computed for the subtree rooted at an arbitrary
    vertex v can be expressed through v's children c1, c2, ..., ck as

        dp_v = g(merge(f(dp_c1, c1), f(dp_c2, c2), ..., f(dp_ck, ck)), v)

    E[v] is a list of (edge_label, neighbour) pairs.  f is called as
    f(child_value, v, child, edge_label); g as g(merged_value, v); e is the
    identity element of merge.
    """

    def __init__(self, N: int, E: List[List[Tuple[int, int]]],
                 f: Callable[[T, int, int, int], T],
                 g: Callable[[T, int], T],
                 merge: Callable[[T, T], T], e: T):
        self.N = N
        self.E = E
        self.f = f
        self.g = g
        self.merge = merge
        self.e = e
        # ret[v]: value of v's subtree when rooted at the root passed to
        # calculate().
        self.ret = [self.e] * self.N
        # dp[v][i]: value of the subtree hanging off edge E[v][i] when the
        # tree is rooted at v (filled by _dfs1 for children, _dfs2 for the
        # parent edge).
        self.dp = [[self.e for _ in range(len(self.E[v]))]
                   for v in range(self.N)]

    def _dfs1(self, root):
        """Bottom-up pass: fill ret[] and the child entries of dp[]."""
        stack = [(root, -1)]
        while stack:
            v, p = stack.pop()
            if v < 0:
                # Post-order visit (vertex was pushed as ~v).
                v = ~v
                acc = self.e
                for i, (c, d) in enumerate(self.E[v]):
                    if d == p:
                        continue
                    self.dp[v][i] = self.ret[d]
                    acc = self.merge(acc, self.f(self.ret[d], v, d, c))
                self.ret[v] = self.g(acc, v)
                continue
            stack.append((~v, p))
            for c, d in self.E[v]:
                if d == p:
                    continue
                stack.append((d, v))

    def _dfs2(self, root):
        """Top-down pass: fill the parent-edge entry of dp[] for every vertex."""
        stack = [(root, -1, self.e)]
        while stack:
            v, p, from_par = stack.pop()
            for i, (c, d) in enumerate(self.E[v]):
                if d == p:
                    self.dp[v][i] = from_par
                    break
            ch = len(self.E[v])
            # Suffix merges so each child can get "everything except itself".
            Sr = [self.e] * (ch + 1)
            for i in range(ch, 0, -1):
                c, d = self.E[v][i - 1]
                Sr[i - 1] = self.merge(Sr[i], self.f(self.dp[v][i - 1], v, d, c))
            Sl = self.e
            for i, (c, d) in enumerate(self.E[v]):
                if d != p:
                    val = self.merge(Sl, Sr[i + 1])
                    stack.append((d, v, self.g(val, v)))
                Sl = self.merge(Sl, self.f(self.dp[v][i], v, d, c))

    def calculate(self, root=0):
        """Run both passes for the tree containing root."""
        self._dfs1(root)
        self._dfs2(root)

    def solve(self, v):
        """Return the DP value for the whole tree rerooted at v.

        Only valid after calculate() has been run on v's tree.
        """
        ans = self.e
        for i, (c, d) in enumerate(self.E[v]):
            ans = self.merge(ans, self.f(self.dp[v][i], v, d, c))
        return self.g(ans, v)


class UnionFind:
    """Disjoint-set union with union by size and path compression."""

    def __init__(self, N):
        self.N = N
        # par[x] < 0 marks a root; -par[x] is then the set size.
        self.par = [-1] * N

    def find(self, x):
        """Return the representative of x's set."""
        if self.par[x] < 0:
            return x
        self.par[x] = self.find(self.par[x])
        return self.par[x]

    def unite(self, x, y):
        """Merge the sets of x and y; return False if already merged."""
        x = self.find(x)
        y = self.find(y)
        if x == y:
            return False
        if self.par[x] > self.par[y]:
            x, y = y, x
        self.par[x] += self.par[y]
        self.par[y] = x
        return True

    def same(self, x, y):
        """Return True if x and y are in the same set."""
        return self.find(x) == self.find(y)

    def size(self, x):
        """Return the size of x's set."""
        return -self.par[self.find(x)]

    def roots(self):
        """Return all set representatives in increasing order."""
        return [i for i in range(self.N) if self.par[i] < 0]


class LCA:
    """Binary-lifting LCA over a forest given as (label, dest) adjacency lists.

    Usage: call _dfs(root) once per tree, then calculate() once.
    """

    def __init__(self, N, E):
        self.N = N
        self.E = E
        self.K = N.bit_length()
        # par[k][v]: 2**k-th ancestor of v, or -1 if it does not exist.
        self.par = [[-1 for _ in range(N)] for _ in range(self.K)]
        self.depth = [None] * N

    def _dfs(self, root):
        """Set depth[] and direct parents for the tree containing root."""
        self.depth[root] = 0
        stack = [(root, -1)]
        while stack:
            v, p = stack.pop()
            if p >= 0:
                self.par[0][v] = p
                self.depth[v] = self.depth[p] + 1
            for _, dest in self.E[v]:
                if dest == p:
                    continue
                stack.append((dest, v))

    def calculate(self):
        """Build the doubling table from the direct parents."""
        for k in range(self.K - 1):
            for i in range(self.N):
                if self.par[k][i] < 0:
                    continue
                self.par[k + 1][i] = self.par[k][self.par[k][i]]

    def la(self, v, x):
        """Return the x-th ancestor of v, or -1 if it is above the root."""
        # Without this guard, bits >= K are ignored and par[k][-1] wraps to
        # the last vertex, so an out-of-range x returned a wrong vertex.
        if x > self.depth[v]:
            return -1
        for k in range(self.K):
            if x & (1 << k):
                v = self.par[k][v]
        return v

    def lca(self, u, v):
        """Return the lowest common ancestor of u and v (same tree)."""
        if self.depth[u] > self.depth[v]:
            u, v = v, u
        d = self.depth[v] - self.depth[u]
        v = self.la(v, d)
        if u == v:
            return u
        for k in reversed(range(self.K)):
            if self.par[k][u] != self.par[k][v]:
                u = self.par[k][u]
                v = self.par[k][v]
        return self.par[0][v]

    def dist(self, u, v):
        """Return the number of edges between u and v (same tree)."""
        return self.depth[u] + self.depth[v] - 2 * self.depth[self.lca(u, v)]

    def jump(self, u, v, x):
        """Return the vertex x steps from u toward v, or -1 if x > dist(u, v)."""
        lca = self.lca(u, v)
        d1 = self.depth[u] - self.depth[lca]
        d2 = self.depth[v] - self.depth[lca]
        if d1 + d2 < x:
            return -1
        if x <= d1:
            return self.la(u, x)
        return self.la(v, d1 + d2 - x)


def min_total_distance(N, edges, queries):
    """Return the minimum total travel distance for yukicoder No.922.

    N       -- number of vertices (0-indexed).
    edges   -- iterable of (u, v) pairs forming a forest.
    queries -- iterable of (a, b) vertex pairs.

    A pair inside one tree costs its tree distance.  A pair spanning two
    trees travels through one airport per tree, so for every tree we choose
    the vertex minimising the summed distance to all cross-tree query
    endpoints that lie in that tree.
    """
    E = [[] for _ in range(N)]
    uf = UnionFind(N)
    for u, v in edges:
        E[u].append((1, v))
        E[v].append((1, u))
        uf.unite(u, v)

    todo = []      # same-tree pairs, answered with LCA distances
    cnt = [0] * N  # cross-tree query endpoints located at each vertex
    for a, b in queries:
        if uf.same(a, b):
            todo.append((a, b))
        else:
            cnt[a] += 1
            cnt[b] += 1

    # DP value: (sum of distances to endpoints in the subtree, endpoint count).
    def f(a, v, ch, cost):
        # Moving up one edge puts every endpoint one step further away.
        return (a[0] + a[1], a[1])

    def g(a, v):
        return (a[0], a[1] + cnt[v])

    def merge(a, b):
        return (a[0] + b[0], a[1] + b[1])

    groups = defaultdict(list)
    for v in range(N):
        groups[uf.find(v)].append(v)

    solver = Rerooting(N, E, f, g, merge, (0, 0))
    lca = LCA(N, E)
    ans = 0
    for members in groups.values():
        root = members[0]
        solver.calculate(root=root)
        lca._dfs(root)
        # No sentinel: a tree's optimum can be ~Q*N, far above 2**30.
        ans += min(solver.solve(v)[0] for v in members)
    lca.calculate()
    for s, t in todo:
        ans += lca.dist(s, t)
    return ans


def main():
    """Read the 1-indexed problem input from stdin and print the answer."""
    N, M, Q = readints()
    edges = []
    for _ in range(M):
        u, v = readints()
        edges.append((u - 1, v - 1))
    queries = []
    for _ in range(Q):
        a, b = readints()
        queries.append((a - 1, b - 1))
    print(min_total_distance(N, edges, queries))


if __name__ == '__main__':
    main()