結果
問題 | No.922 東北きりきざむたん |
ユーザー | terasa |
提出日時 | 2022-11-03 23:38:56 |
言語 | PyPy3 (7.3.15) |
結果 | AC |
実行時間 | 1,127 ms / 2,000 ms |
コード長 | 6,521 bytes |
コンパイル時間 | 251 ms |
コンパイル使用メモリ | 82,456 KB |
実行使用メモリ | 136,608 KB |
最終ジャッジ日時 | 2024-07-18 05:34:09 |
合計ジャッジ時間 | 18,784 ms |
ジャッジサーバーID (参考情報) | judge2 / judge4 |
(要ログイン)
テストケース
テストケース表示
入力 | 結果 | 実行時間 / 実行使用メモリ |
---|---|---|
testcase_00 | AC | 71 ms / 69,504 KB |
testcase_01 | AC | 69 ms / 69,632 KB |
testcase_02 | AC | 69 ms / 69,376 KB |
testcase_03 | AC | 69 ms / 69,376 KB |
testcase_04 | AC | 95 ms / 78,336 KB |
testcase_05 | AC | 77 ms / 73,856 KB |
testcase_06 | AC | 108 ms / 78,912 KB |
testcase_07 | AC | 93 ms / 78,040 KB |
testcase_08 | AC | 103 ms / 78,592 KB |
testcase_09 | AC | 612 ms / 109,952 KB |
testcase_10 | AC | 497 ms / 91,008 KB |
testcase_11 | AC | 530 ms / 103,680 KB |
testcase_12 | AC | 356 ms / 114,048 KB |
testcase_13 | AC | 283 ms / 89,244 KB |
testcase_14 | AC | 646 ms / 128,000 KB |
testcase_15 | AC | 256 ms / 128,260 KB |
testcase_16 | AC | 979 ms / 131,864 KB |
testcase_17 | AC | 983 ms / 132,224 KB |
testcase_18 | AC | 1,010 ms / 133,288 KB |
testcase_19 | AC | 959 ms / 132,588 KB |
testcase_20 | AC | 951 ms / 132,236 KB |
testcase_21 | AC | 959 ms / 130,260 KB |
testcase_22 | AC | 973 ms / 130,716 KB |
testcase_23 | AC | 1,124 ms / 133,760 KB |
testcase_24 | AC | 1,127 ms / 135,552 KB |
testcase_25 | AC | 849 ms / 134,272 KB |
testcase_26 | AC | 854 ms / 134,400 KB |
testcase_27 | AC | 859 ms / 134,756 KB |
testcase_28 | AC | 290 ms / 136,608 KB |
testcase_29 | AC | 787 ms / 135,168 KB |
ソースコード
from typing import List, Tuple, Callable, TypeVar
import sys
import itertools
import heapq
import bisect
from collections import deque, defaultdict
from functools import lru_cache, cmp_to_key

input = sys.stdin.readline

# for AtCoder Easy test
if __file__ != 'prog.py':
    sys.setrecursionlimit(10 ** 6)


def readints():
    """Read one stdin line and return its whitespace-separated ints (lazy map)."""
    return map(int, input().split())


def readlist():
    """Read one stdin line and return its ints as a list."""
    return list(readints())


def readstr():
    """Read one stdin line with trailing whitespace removed."""
    return input().rstrip()


T = TypeVar('T')


class Rerooting:
    """Rerooting DP: evaluate a subtree DP with every vertex taken as the root.

    Reference: https://null-mn.hatenablog.com/entry/2020/04/14/124151

    Applies when the value dp_v of the subtree rooted at a vertex v can be
    written in terms of v's children c1, c2, ..., ck as

        dp_v = g(merge(f(dp_c1, v, c1, w1), ..., f(dp_ck, v, ck, wk)), v)

    where wi is the weight stored on edge (v, ci). ``merge`` is assumed to be
    associative and commutative with identity ``e``: ``_dfs2`` combines
    prefix and (reverse-ordered) suffix merges.

    ``E[v]`` is a list of ``(weight, neighbor)`` pairs. The graph may be a
    forest; run ``calculate(root)`` on a tree before calling ``solve(v)`` for
    its vertices. The DFS passes do not track visited vertices, so the graph
    must be acyclic.
    """

    def __init__(self, N: int, E: List[List[Tuple[int, int]]],
                 f: Callable[[T, int, int, int], T],
                 g: Callable[[T, int], T],
                 merge: Callable[[T, T], T],
                 e: T):
        self.N = N
        self.E = E
        self.f = f
        self.g = g
        self.merge = merge
        self.e = e
        # ret[v]: DP value of v's subtree w.r.t. the root given to calculate().
        self.ret = [self.e] * self.N
        # dp[v][i]: DP value of the part of the tree reached from v through
        # edge E[v][i] (for the parent edge: everything outside v's subtree).
        self.dp = [[self.e for _ in range(len(self.E[v]))] for v in range(self.N)]

    def _dfs1(self, root):
        """Bottom-up pass: set ret[v] and dp[v][i] for every child edge i."""
        stack = [(root, -1)]
        while stack:
            v, p = stack.pop()
            if v < 0:
                # Post-order visit, encoded as ~v: all children are finished.
                v = ~v
                acc = self.e
                for i, (c, d) in enumerate(self.E[v]):
                    if d == p:
                        continue
                    self.dp[v][i] = self.ret[d]
                    acc = self.merge(acc, self.f(self.ret[d], v, d, c))
                self.ret[v] = self.g(acc, v)
                continue
            stack.append((~v, p))
            for c, d in self.E[v]:
                if d != p:
                    stack.append((d, v))

    def _dfs2(self, root):
        """Top-down pass: set dp[v][i] for each vertex's parent edge.

        For each child d of v, the value "v's tree without d's subtree" is
        g(merge(prefix, suffix), v), so all children of v cost O(deg v).
        """
        stack = [(root, -1, self.e)]
        while stack:
            v, p, from_par = stack.pop()
            # Record the value arriving through the parent edge.
            for i, (c, d) in enumerate(self.E[v]):
                if d == p:
                    self.dp[v][i] = from_par
                    break
            ch = len(self.E[v])
            # Sr[i]: merge of f(...) over neighbors i..ch-1 (Sr[ch] = e),
            # accumulated from the last neighbor backwards.
            Sr = [self.e] * (ch + 1)
            for i in range(ch, 0, -1):
                c, d = self.E[v][i - 1]
                Sr[i - 1] = self.merge(Sr[i], self.f(self.dp[v][i - 1], v, d, c))
            # Sl: merge of f(...) over neighbors 0..i-1.
            Sl = self.e
            for i, (c, d) in enumerate(self.E[v]):
                if d != p:
                    val = self.merge(Sl, Sr[i + 1])
                    stack.append((d, v, self.g(val, v)))
                Sl = self.merge(Sl, self.f(self.dp[v][i], v, d, c))

    def calculate(self, root=0):
        """Run both passes on the tree containing ``root``."""
        self._dfs1(root)
        self._dfs2(root)

    def solve(self, v):
        """Return the DP value of v's whole tree with v as the root.

        Only meaningful after calculate() has been run on v's tree.
        """
        ans = self.e
        for i, (c, d) in enumerate(self.E[v]):
            ans = self.merge(ans, self.f(self.dp[v][i], v, d, c))
        return self.g(ans, v)


class UnionFind:
    """Disjoint-set union with union by size and path compression.

    par[x] < 0 marks x as a root whose set has -par[x] elements; otherwise
    par[x] is x's parent.
    """

    def __init__(self, N):
        self.N = N
        self.par = [-1] * N

    def find(self, x):
        """Return the root of x's set, compressing the path (recursively)."""
        if self.par[x] < 0:
            return x
        else:
            self.par[x] = self.find(self.par[x])
            return self.par[x]

    def unite(self, x, y):
        """Merge the sets of x and y; return False if already in one set."""
        x = self.find(x)
        y = self.find(y)
        if x == y:
            return False
        # Attach the smaller set under the larger one.
        if self.par[x] > self.par[y]:
            x, y = y, x
        self.par[x] += self.par[y]
        self.par[y] = x
        return True

    def same(self, x, y):
        """Return True if x and y are in the same set."""
        return self.find(x) == self.find(y)

    def size(self, x):
        """Return the number of elements in x's set."""
        return -self.par[self.find(x)]

    def roots(self):
        """Return all current set roots in increasing index order."""
        return [i for i in range(self.N) if self.par[i] < 0]


class LCA:
    """Lowest common ancestor by binary lifting on a forest.

    ``E[v]`` is a list of ``(weight, neighbor)`` pairs; weights are ignored,
    so ``dist`` counts edges. Call ``_dfs(root)`` for every tree, then
    ``calculate()`` once, before answering queries.
    """

    def __init__(self, N, E):
        self.N = N
        self.E = E
        self.K = N.bit_length()
        # par[k][v]: the 2**k-th ancestor of v, or -1 if there is none.
        self.par = [[-1 for _ in range(N)] for _ in range(self.K)]
        # depth[v]: number of edges from v to its tree's root (None until visited).
        self.depth = [None] * N

    def _dfs(self, root):
        """Set depth[] and par[0] for every vertex in root's tree."""
        self.depth[root] = 0
        stack = [(root, -1)]
        while stack:
            v, p = stack.pop()
            if not p < 0:
                self.par[0][v] = p
                self.depth[v] = self.depth[p] + 1
            for _, dest in self.E[v]:
                if dest == p:
                    continue
                stack.append((dest, v))

    def calculate(self):
        """Fill the doubling table par[1..K-1] from par[0]."""
        for k in range(self.K - 1):
            for i in range(self.N):
                if self.par[k][i] < 0:
                    continue
                self.par[k + 1][i] = self.par[k][self.par[k][i]]

    def la(self, v, x):
        """Return the x-th ancestor of v.

        Requires 0 <= x <= depth[v]; a larger x would index par[k][-1]
        (Python negative indexing) and return a meaningless vertex.
        """
        for k in range(self.K):
            if x & (1 << k):
                v = self.par[k][v]
        return v

    def lca(self, u, v):
        """Return the lowest common ancestor of u and v (same tree required)."""
        if self.depth[u] > self.depth[v]:
            u, v = v, u
        d = self.depth[v] - self.depth[u]
        v = self.la(v, d)
        if u == v:
            return u
        for k in range(self.K)[::-1]:
            if self.par[k][u] != self.par[k][v]:
                u = self.par[k][u]
                v = self.par[k][v]
        return self.par[0][v]

    def dist(self, u, v):
        """Return the number of edges on the u-v path (same tree required)."""
        return self.depth[u] + self.depth[v] - 2 * self.depth[self.lca(u, v)]

    def jump(self, u, v, x):
        """Return the vertex x steps from u toward v, or -1 if the path is shorter."""
        lca = self.lca(u, v)
        d1 = self.depth[u] - self.depth[lca]
        d2 = self.depth[v] - self.depth[lca]
        if d1 + d2 < x:
            return -1
        if x <= d1:
            return self.la(u, x)
        return self.la(v, d1 + d2 - x)


def min_total_cost(N, edges, queries):
    """Compute the answer for a forest on N vertices and a list of queries.

    Parameters:
        N: number of vertices, labelled 0..N-1.
        edges: iterable of 0-indexed ``(u, v)`` pairs, each of length 1. They
            must form a forest (the DFS passes do not track visited vertices).
        queries: iterable of 0-indexed ``(a, b)`` pairs.

    Returns:
        The sum over connected components of min_h sum_v cnt[v] * dist(v, h),
        plus dist(a, b) for every query whose endpoints share a component.
        Here cnt[v] is the number of cross-component query endpoints at v.
    """
    adj = [[] for _ in range(N)]
    uf = UnionFind(N)
    n_comp = N
    for u, v in edges:
        adj[u].append((1, v))
        adj[v].append((1, u))
        if uf.unite(u, v):
            n_comp -= 1

    same_comp = []  # queries answered by the tree distance alone
    cnt = [0] * N   # cross-component query endpoints located at each vertex
    for a, b in queries:
        if uf.same(a, b):
            same_comp.append((a, b))
        else:
            cnt[a] += 1
            cnt[b] += 1

    # Subtree DP value: (sum of cnt[u] * dist(u, subtree root), sum of cnt[u]).
    def f(a, v, ch, cost):
        # Crossing one edge puts every counted endpoint one step farther away.
        return (a[0] + a[1], a[1])

    def g(a, v):
        # v is at distance 0 from itself, so only the count grows.
        return (a[0], a[1] + cnt[v])

    def merge(a, b):
        return (a[0] + b[0], a[1] + b[1])

    comp_index = {r: i for i, r in enumerate(uf.roots())}
    comps = [[] for _ in range(n_comp)]
    for v in range(N):
        comps[comp_index[uf.find(v)]].append(v)

    solver = Rerooting(N, adj, f, g, merge, (0, 0))
    lca = LCA(N, adj)
    total = 0
    for comp in comps:
        solver.calculate(root=comp[0])
        lca._dfs(comp[0])
        # Every component is non-empty, so min() needs no sentinel. A finite
        # seed such as 1 << 30 would silently cap costs above 2**30.
        total += min(solver.solve(v)[0] for v in comp)
    lca.calculate()

    for s, t in same_comp:
        total += lca.dist(s, t)
    return total


def main():
    """Read the 1-indexed input from stdin and print the answer."""
    N, M, Q = readints()
    edges = []
    for _ in range(M):
        u, v = readints()
        edges.append((u - 1, v - 1))
    queries = []
    for _ in range(Q):
        a, b = readints()
        queries.append((a - 1, b - 1))
    print(min_total_cost(N, edges, queries))


if __name__ == '__main__':
    main()