結果
問題 | No.922 東北きりきざむたん |
ユーザー | terasa |
提出日時 | 2022-11-03 23:31:35 |
言語 | PyPy3 (7.3.15) |
結果 | AC |
実行時間 | 999 ms / 2,000 ms |
コード長 | 6,562 bytes |
コンパイル時間 | 126 ms |
コンパイル使用メモリ | 81,888 KB |
実行使用メモリ | 136,784 KB |
最終ジャッジ日時 | 2024-07-18 05:30:16 |
合計ジャッジ時間 | 15,696 ms |
ジャッジサーバーID (参考情報) | judge1 / judge2 |
(要ログイン)
テストケース
テストケース表示入力 | 結果 | 実行時間 実行使用メモリ |
---|---|---|
testcase_00 | AC | 66 ms 69,248 KB |
testcase_01 | AC | 67 ms 69,504 KB |
testcase_02 | AC | 61 ms 69,120 KB |
testcase_03 | AC | 61 ms 69,248 KB |
testcase_04 | AC | 81 ms 77,952 KB |
testcase_05 | AC | 70 ms 73,728 KB |
testcase_06 | AC | 93 ms 78,420 KB |
testcase_07 | AC | 86 ms 77,952 KB |
testcase_08 | AC | 87 ms 78,296 KB |
testcase_09 | AC | 495 ms 109,584 KB |
testcase_10 | AC | 428 ms 91,136 KB |
testcase_11 | AC | 472 ms 103,540 KB |
testcase_12 | AC | 316 ms 113,552 KB |
testcase_13 | AC | 242 ms 89,068 KB |
testcase_14 | AC | 556 ms 128,268 KB |
testcase_15 | AC | 218 ms 128,236 KB |
testcase_16 | AC | 823 ms 131,684 KB |
testcase_17 | AC | 821 ms 132,520 KB |
testcase_18 | AC | 854 ms 133,584 KB |
testcase_19 | AC | 823 ms 132,396 KB |
testcase_20 | AC | 801 ms 132,532 KB |
testcase_21 | AC | 822 ms 130,024 KB |
testcase_22 | AC | 848 ms 130,612 KB |
testcase_23 | AC | 960 ms 133,568 KB |
testcase_24 | AC | 999 ms 134,980 KB |
testcase_25 | AC | 721 ms 133,992 KB |
testcase_26 | AC | 729 ms 134,192 KB |
testcase_27 | AC | 732 ms 134,768 KB |
testcase_28 | AC | 253 ms 136,784 KB |
testcase_29 | AC | 675 ms 135,408 KB |
ソースコード
from typing import Callable, List, Optional, Tuple, TypeVar
import sys
import itertools
import heapq
import bisect
from collections import deque, defaultdict
from functools import lru_cache, cmp_to_key

# for AtCoder Easy test
# ``__file__`` is looked up defensively because it is undefined when the
# source is fed to the interpreter through exec()/stdin.
if globals().get('__file__') != 'prog.py':
    sys.setrecursionlimit(10 ** 6)


def readints():
    """Read one line from stdin and return an iterator over its integers."""
    return map(int, sys.stdin.readline().split())


def readlist():
    """Read one line from stdin and return its integers as a list."""
    return list(readints())


def readstr():
    """Read one line from stdin with trailing whitespace removed."""
    return sys.stdin.readline().rstrip()


T = TypeVar('T')


class Rerooting:
    """Generic rerooting DP over a tree (or one tree of a forest).

    reference: https://null-mn.hatenablog.com/entry/2020/04/14/124151

    The value dp_v computed for the subtree rooted at some vertex v is
    expressed through v's children c1, c2, ..., ck as

        dp_v = g(merge(f(dp_c1, c1), f(dp_c2, c2), ..., f(dp_ck, ck)), v)

    Args:
        N: number of vertices.
        E: adjacency list; ``E[v]`` is a list of ``(cost, dest)`` pairs.
        f: called as ``f(value, v, dest, cost)``; lifts the value of the
            subtree hanging off ``dest`` across the edge ``v -- dest``.
        g: called as ``g(value, v)``; folds vertex ``v`` itself into the
            merged value of its neighbours.
        merge: associative combination of two lifted values.
        e: identity element of ``merge``.
    """

    def __init__(self,
                 N: int,
                 E: List[List[Tuple[int, int]]],
                 f: Callable[[T, int, int, int], T],
                 g: Callable[[T, int], T],
                 merge: Callable[[T, T], T],
                 e: T):
        self.N = N
        self.E = E
        self.f = f
        self.g = g
        self.merge = merge
        self.e = e
        # ret[v]: value of the subtree of v for the root used by _dfs1.
        self.ret = [self.e] * self.N
        # dp[v][i]: value of the part of the tree reached from v through
        # its i-th edge (v itself excluded).
        self.dp = [[self.e for _ in range(len(self.E[v]))]
                   for v in range(self.N)]

    def _dfs1(self, root):
        """Bottom-up pass: fill ``ret`` and the child entries of ``dp``."""
        stack = [(root, -1)]
        while stack:
            v, p = stack.pop()
            if v < 0:
                # Post-order visit, encoded as the bitwise complement of v.
                v = ~v
                acc = self.e
                for i, (c, d) in enumerate(self.E[v]):
                    if d == p:
                        continue
                    self.dp[v][i] = self.ret[d]
                    acc = self.merge(acc, self.f(self.ret[d], v, d, c))
                self.ret[v] = self.g(acc, v)
                continue
            stack.append((~v, p))
            for c, d in self.E[v]:
                if d != p:
                    stack.append((d, v))

    def _dfs2(self, root):
        """Top-down pass: fill the parent entry of ``dp`` for every vertex."""
        stack = [(root, -1, self.e)]
        while stack:
            v, p, from_par = stack.pop()
            edges = self.E[v]
            for i, (c, d) in enumerate(edges):
                if d == p:
                    self.dp[v][i] = from_par
                    break
            ch = len(edges)
            # Sr[i]: merge of the lifted values of edges i, i + 1, ..., ch - 1.
            Sr = [self.e] * (ch + 1)
            for i in range(ch, 0, -1):
                c, d = edges[i - 1]
                Sr[i - 1] = self.merge(
                    Sr[i], self.f(self.dp[v][i - 1], v, d, c))
            # Sl: merge of the lifted values of edges 0, ..., i - 1.
            Sl = self.e
            for i, (c, d) in enumerate(edges):
                if d != p:
                    # Everything around v except the branch towards d.
                    val = self.merge(Sl, Sr[i + 1])
                    stack.append((d, v, self.g(val, v)))
                Sl = self.merge(Sl, self.f(self.dp[v][i], v, d, c))

    def calculate(self, root=0):
        """Run both passes for the tree that contains ``root``."""
        self._dfs1(root)
        self._dfs2(root)

    def solve(self, v):
        """Return the DP value obtained when ``v`` is the root of its tree."""
        ans = self.e
        for i, (c, d) in enumerate(self.E[v]):
            ans = self.merge(ans, self.f(self.dp[v][i], v, d, c))
        return self.g(ans, v)


class UnionFind:
    """Disjoint-set forest with union by size and path compression."""

    def __init__(self, N):
        self.N = N
        # par[x] < 0: x is a root and -par[x] is the size of its set.
        # par[x] >= 0: index of the parent of x.
        self.par = [-1] * N

    def find(self, x):
        """Return the representative of the set containing ``x``."""
        root = x
        while self.par[root] >= 0:
            root = self.par[root]
        # Path compression: hang every vertex on the path below the root.
        while x != root:
            nxt = self.par[x]
            self.par[x] = root
            x = nxt
        return root

    def unite(self, x, y):
        """Merge the sets of ``x`` and ``y``; return False if already one."""
        x = self.find(x)
        y = self.find(y)
        if x == y:
            return False
        if self.par[x] > self.par[y]:
            # Keep the larger set as the new root.
            x, y = y, x
        self.par[x] += self.par[y]
        self.par[y] = x
        return True

    def same(self, x, y):
        """Return True when ``x`` and ``y`` belong to the same set."""
        return self.find(x) == self.find(y)

    def size(self, x):
        """Return the number of elements in the set containing ``x``."""
        return -self.par[self.find(x)]

    def roots(self):
        """Return the representatives of all sets, in increasing order."""
        return [i for i in range(self.N) if self.par[i] < 0]


class LCA:
    """Lowest common ancestor by binary lifting over a forest.

    Usage: call ``add_tree(root)`` once per tree, then ``calculate()``,
    then query with ``lca`` / ``dist`` / ``jump`` for vertices of one tree.

    Args:
        N: number of vertices.
        E: adjacency list; ``E[v]`` is a list of ``(cost, dest)`` pairs.
            Only ``dest`` is used, so depths count edges.
    """

    def __init__(self, N, E):
        self.N = N
        self.E = E
        self.K = N.bit_length()
        # par[k][v]: the 2**k-th ancestor of v, or -1 when there is none.
        self.par = [[-1 for _ in range(N)] for _ in range(self.K)]
        self.depth = [None] * N

    def _dfs(self, root):
        """Record depth and direct parent for the tree containing ``root``."""
        self.depth[root] = 0
        stack = [(root, -1)]
        while stack:
            v, p = stack.pop()
            if p >= 0:
                self.par[0][v] = p
                self.depth[v] = self.depth[p] + 1
            for _, dest in self.E[v]:
                if dest != p:
                    stack.append((dest, v))

    def add_tree(self, root):
        """Public entry point for registering the tree rooted at ``root``."""
        self._dfs(root)

    def calculate(self):
        """Build the binary-lifting table from the direct parents."""
        for k in range(self.K - 1):
            cur = self.par[k]
            nxt = self.par[k + 1]
            for i in range(self.N):
                if cur[i] >= 0:
                    nxt[i] = cur[cur[i]]

    def la(self, v, x):
        """Return the ancestor ``x`` levels above ``v``.

        ``x`` is expected not to exceed ``depth[v]``.
        """
        for k in range(self.K):
            if x & (1 << k):
                v = self.par[k][v]
        return v

    def lca(self, u, v):
        """Return the lowest common ancestor of ``u`` and ``v``."""
        if self.depth[u] > self.depth[v]:
            u, v = v, u
        v = self.la(v, self.depth[v] - self.depth[u])
        if u == v:
            return u
        for k in range(self.K - 1, -1, -1):
            if self.par[k][u] != self.par[k][v]:
                u = self.par[k][u]
                v = self.par[k][v]
        return self.par[0][v]

    def dist(self, u, v):
        """Return the number of edges on the path between ``u`` and ``v``."""
        return (self.depth[u] + self.depth[v]
                - 2 * self.depth[self.lca(u, v)])

    def jump(self, u, v, x):
        """Return the vertex ``x`` steps from ``u`` towards ``v``, or -1."""
        lca = self.lca(u, v)
        d1 = self.depth[u] - self.depth[lca]
        d2 = self.depth[v] - self.depth[lca]
        if d1 + d2 < x:
            return -1
        if x <= d1:
            return self.la(u, x)
        return self.la(v, d1 + d2 - x)


def min_gather_costs(n, adj, groups, weight):
    """Return the cheapest weighted gathering cost of every tree.

    For each vertex list in ``groups`` (the vertices of one tree of the
    forest, never empty) the result holds

        min over v in group of  sum(weight[u] * dist(u, v) for u in group)

    where ``dist`` counts edges.

    Args:
        n: number of vertices in the forest.
        adj: adjacency list of ``(cost, dest)`` pairs; every edge is
            treated as having length 1.
        groups: one list of vertices per tree.
        weight: ``weight[v]`` is the multiplicity of vertex ``v``.

    Returns:
        A list with one minimum cost per group, in the order of ``groups``.
    """
    # DP state: (sum of weighted distances, sum of weights).

    def push_up(state, v, ch, cost):
        # Crossing one edge moves every weighted unit one step further away.
        return (state[0] + state[1], state[1])

    def absorb(state, v):
        return (state[0], state[1] + weight[v])

    def combine(a, b):
        return (a[0] + b[0], a[1] + b[1])

    solver = Rerooting(n, adj, push_up, absorb, combine, (0, 0))
    costs = []
    for group in groups:
        solver.calculate(root=group[0])
        # Take the real minimum instead of starting from a finite sentinel:
        # a cost can exceed 2**30, so such a sentinel would clamp the answer.
        costs.append(min(solver.solve(v)[0] for v in group))
    return costs


def solve_case(n, edges, queries):
    """Return the minimum total walking distance for one test case.

    Args:
        n: number of vertices.
        edges: iterable of 1-indexed ``(u, v)`` pairs forming a forest.
        queries: iterable of 1-indexed ``(a, b)`` trips.

    Returns:
        The sum of tree distances of the trips whose endpoints share a
        tree, plus, for every tree, the cheapest gathering cost of the
        endpoints of the trips that leave that tree.
    """
    adj = [[] for _ in range(n)]
    uf = UnionFind(n)
    for u, v in edges:
        u -= 1
        v -= 1
        adj[u].append((1, v))
        adj[v].append((1, u))
        uf.unite(u, v)

    same_tree = []     # trips answered directly by a tree distance
    cnt = [0] * n      # cnt[v]: cross-tree trips that start or end at v
    for a, b in queries:
        a -= 1
        b -= 1
        if uf.same(a, b):
            same_tree.append((a, b))
        else:
            cnt[a] += 1
            cnt[b] += 1

    group_index = {r: i for i, r in enumerate(uf.roots())}
    groups = [[] for _ in group_index]
    for v in range(n):
        groups[group_index[uf.find(v)]].append(v)

    ans = sum(min_gather_costs(n, adj, groups, cnt))

    if same_tree:
        lca = LCA(n, adj)
        for group in groups:
            lca.add_tree(group[0])
        lca.calculate()
        for s, t in same_tree:
            ans += lca.dist(s, t)
    return ans


def main():
    """Read ``N M Q``, then M edges and Q trips from stdin; print the answer."""
    n, m, q = readints()
    edges = [tuple(readints()) for _ in range(m)]
    queries = [tuple(readints()) for _ in range(q)]
    print(solve_case(n, edges, queries))


if __name__ == '__main__':
    main()