結果
問題 | No.922 東北きりきざむたん |
ユーザー | terasa |
提出日時 | 2022-11-03 23:18:14 |
言語 | PyPy3 (7.3.15) |
結果 | AC |
実行時間 | 1,399 ms / 2,000 ms |
コード長 | 6,847 bytes |
コンパイル時間 | 145 ms |
コンパイル使用メモリ | 82,468 KB |
実行使用メモリ | 142,580 KB |
最終ジャッジ日時 | 2024-07-18 05:23:40 |
合計ジャッジ時間 | 20,408 ms |
ジャッジサーバーID (参考情報) | judge3 / judge2 |
(要ログイン)
テストケース
テストケース表示
入力 | 結果 | 実行時間 | 実行使用メモリ |
---|---|---|---|
testcase_00 | AC | 70 ms | 71,100 KB |
testcase_01 | AC | 61 ms | 69,988 KB |
testcase_02 | AC | 62 ms | 70,576 KB |
testcase_03 | AC | 60 ms | 69,900 KB |
testcase_04 | AC | 83 ms | 78,124 KB |
testcase_05 | AC | 69 ms | 74,156 KB |
testcase_06 | AC | 97 ms | 78,660 KB |
testcase_07 | AC | 90 ms | 78,496 KB |
testcase_08 | AC | 109 ms | 78,368 KB |
testcase_09 | AC | 675 ms | 116,540 KB |
testcase_10 | AC | 706 ms | 97,064 KB |
testcase_11 | AC | 667 ms | 110,320 KB |
testcase_12 | AC | 389 ms | 118,240 KB |
testcase_13 | AC | 304 ms | 91,612 KB |
testcase_14 | AC | 775 ms | 137,204 KB |
testcase_15 | AC | 258 ms | 119,180 KB |
testcase_16 | AC | 1,221 ms | 141,060 KB |
testcase_17 | AC | 1,168 ms | 140,972 KB |
testcase_18 | AC | 1,220 ms | 142,580 KB |
testcase_19 | AC | 1,173 ms | 141,516 KB |
testcase_20 | AC | 1,148 ms | 141,768 KB |
testcase_21 | AC | 1,399 ms | 142,132 KB |
testcase_22 | AC | 1,350 ms | 141,808 KB |
testcase_23 | AC | 1,044 ms | 138,336 KB |
testcase_24 | AC | 982 ms | 138,628 KB |
testcase_25 | AC | 811 ms | 138,396 KB |
testcase_26 | AC | 797 ms | 137,796 KB |
testcase_27 | AC | 751 ms | 137,844 KB |
testcase_28 | AC | 336 ms | 131,972 KB |
testcase_29 | AC | 708 ms | 138,480 KB |
ソースコード
from typing import Callable, List, Optional, Tuple, TypeVar
import sys
import itertools
import heapq
import bisect
from collections import deque, defaultdict
from functools import lru_cache, cmp_to_key

input = sys.stdin.readline

# Raise the recursion limit unless the file is named 'prog.py'
# (original note: "for AtCoder Easy test").
if __file__ != 'prog.py':
    sys.setrecursionlimit(10 ** 6)


def readints():
    """Read one stdin line and return its whitespace-separated ints as a map iterator."""
    return map(int, input().split())


def readlist():
    """Read one stdin line and return its ints as a list."""
    return list(readints())


def readstr():
    """Read one stdin line with trailing whitespace removed."""
    return input().rstrip()


T = TypeVar('T')


class Rerooting:
    """Rerooting DP: evaluate a tree DP with every vertex of a tree taken as root.

    reference: https://null-mn.hatenablog.com/entry/2020/04/14/124151

    For the subtree rooted at a vertex v with neighbours (children) c1..ck,
    the DP value dp_v must be expressible as

        dp_v = g(merge(f(dp_c1, v, c1, w1), ..., f(dp_ck, v, ck, wk)), v)

    where w_i is the weight stored on edge (v, c_i) and ``e`` is the identity
    of ``merge``. Prefix and suffix folds are combined in different orders
    (see ``_dfs2``), so ``merge`` must be associative AND commutative.

    Adjacency format: ``E[v]`` is a list of ``(weight, neighbour)`` pairs.
    Call ``calculate(root)`` once per connected component, then ``solve(v)``
    for any vertex v of that component.
    """

    def __init__(self, N: int, E: List[Tuple[int, int]],
                 f: Callable[[T, int, int, int], T],
                 g: Callable[[T, int], T],
                 merge: Callable[[T, T], T], e: T):
        self.N = N
        self.E = E
        self.f = f
        self.g = g
        self.merge = merge
        self.e = e
        # ret[v]: DP value of v's subtree in the DFS from the chosen root.
        self.ret = [self.e] * self.N
        # dp[v][i]: DP value of the side across v's i-th edge, rooted at that neighbour.
        self.dp = [[self.e for _ in range(len(self.E[v]))] for v in range(self.N)]

    def _dfs1(self, root):
        """Bottom-up pass: fill ret[] and the child entries of dp[] (iterative post-order)."""
        stack = [(root, -1)]
        while stack:
            v, p = stack.pop()
            if v < 0:
                # ~v marks the post-order visit: all children of v are done.
                v = ~v
                acc = self.e
                for i, (c, d) in enumerate(self.E[v]):
                    if d == p:
                        continue
                    self.dp[v][i] = self.ret[d]
                    acc = self.merge(acc, self.f(self.ret[d], v, d, c))
                self.ret[v] = self.g(acc, v)
                continue
            stack.append((~v, p))
            for i, (c, d) in enumerate(self.E[v]):
                if d == p:
                    continue
                stack.append((d, v))

    def _dfs2(self, root):
        """Top-down pass: fill the parent entry of dp[] for every non-root vertex."""
        stack = [(root, -1, self.e)]
        while stack:
            v, p, from_par = stack.pop()
            for i, (c, d) in enumerate(self.E[v]):
                if d == p:
                    self.dp[v][i] = from_par
                    break
            ch = len(self.E[v])
            # Sr[i]: merge of f-contributions of edges i..ch-1 (folded right to left).
            Sr = [self.e] * (ch + 1)
            for i in range(ch, 0, -1):
                c, d = self.E[v][i - 1]
                Sr[i - 1] = self.merge(Sr[i], self.f(self.dp[v][i - 1], v, d, c))
            # Sl: merge of f-contributions of edges 0..i-1 (folded left to right).
            Sl = self.e
            for i, (c, d) in enumerate(self.E[v]):
                if d != p:
                    # Value of v's side when viewed from child d: everything except edge i.
                    val = self.merge(Sl, Sr[i + 1])
                    stack.append((d, v, self.g(val, v)))
                Sl = self.merge(Sl, self.f(self.dp[v][i], v, d, c))

    def calculate(self, root=0):
        """Run both passes over the component containing ``root``."""
        self._dfs1(root)
        self._dfs2(root)

    def solve(self, v):
        """Return the DP value of the whole component when rooted at ``v``."""
        ans = self.e
        for i, (c, d) in enumerate(self.E[v]):
            ans = self.merge(ans, self.f(self.dp[v][i], v, d, c))
        return self.g(ans, v)


class UnionFind:
    """Disjoint-set union with union by size and recursive path compression."""

    def __init__(self, N):
        self.N = N
        # par[x] < 0: x is a root and -par[x] is its set size; otherwise par[x] is x's parent.
        self.par = [-1] * N

    def find(self, x):
        """Return the representative of x's set."""
        if self.par[x] < 0:
            return x
        self.par[x] = self.find(self.par[x])
        return self.par[x]

    def unite(self, x, y):
        """Merge the sets of x and y; return False if they were already merged."""
        x = self.find(x)
        y = self.find(y)
        if x == y:
            return False
        if self.par[x] > self.par[y]:
            x, y = y, x
        self.par[x] += self.par[y]
        self.par[y] = x
        return True

    def same(self, x, y):
        """Return True if x and y are in the same set."""
        return self.find(x) == self.find(y)

    def size(self, x):
        """Return the size of x's set."""
        return -self.par[self.find(x)]

    def roots(self):
        """Return all set representatives in increasing index order."""
        return [i for i in range(self.N) if self.par[i] < 0]


class LCA:
    """Binary-lifting ancestor queries over a forest.

    ``E[v]`` holds ``(weight, neighbour)`` pairs; weights are ignored and
    depth/distance are counted in edges. Call ``_dfs(root)`` once per tree,
    then ``calculate()`` once, before any query.
    """

    def __init__(self, N, E):
        self.N = N
        self.E = E
        self.K = N.bit_length()
        # par[k][v]: 2**k-th ancestor of v, or -1 if it does not exist.
        self.par = [[-1 for _ in range(N)] for _ in range(self.K)]
        self.depth = [None] * N

    def _dfs(self, root):
        """Set parent and depth for every vertex in root's tree (iterative DFS)."""
        self.depth[root] = 0
        stack = [(root, -1)]
        while stack:
            v, p = stack.pop()
            if p >= 0:
                self.par[0][v] = p
                self.depth[v] = self.depth[p] + 1
            for _, dest in self.E[v]:
                if dest == p:
                    continue
                stack.append((dest, v))

    def calculate(self):
        """Fill the doubling table par[1..K-1] from par[0]."""
        for k in range(self.K - 1):
            for i in range(self.N):
                if self.par[k][i] < 0:
                    continue
                self.par[k + 1][i] = self.par[k][self.par[k][i]]

    def la(self, v, x):
        """Return the x-th ancestor of v (x >= 0), or -1 if it is above the root."""
        # Bits at or above K would be silently ignored by the loop below.
        if x >= 1 << self.K:
            return -1
        for k in range(self.K):
            if x & (1 << k):
                v = self.par[k][v]
                # Stop here: indexing par[k][-1] would read the last vertex's row.
                if v < 0:
                    return -1
        return v

    def lca(self, u, v):
        """Return the lowest common ancestor of u and v, or -1 if they are in different trees."""
        if self.depth[u] > self.depth[v]:
            u, v = v, u
        d = self.depth[v] - self.depth[u]
        v = self.la(v, d)
        if u == v:
            return u
        for k in range(self.K)[::-1]:
            if self.par[k][u] != self.par[k][v]:
                u = self.par[k][u]
                v = self.par[k][v]
        return self.par[0][v]

    def dist(self, u, v):
        """Return the number of edges between u and v, or -1 if they are in different trees."""
        w = self.lca(u, v)
        if w < 0:
            return -1
        return self.depth[u] + self.depth[v] - 2 * self.depth[w]

    def jump(self, u, v, x):
        """Return the vertex x steps from u towards v, or -1 if no such vertex exists."""
        w = self.lca(u, v)
        if w < 0:
            return -1
        d1 = self.depth[u] - self.depth[w]
        d2 = self.depth[v] - self.depth[w]
        if d1 + d2 < x:
            return -1
        if x <= d1:
            return self.la(u, x)
        return self.la(v, d1 + d2 - x)


def min_total_cost(N, edges, queries):
    """Return the total cost that this script prints for the given instance.

    Every connected component gets one "airport" vertex.
    - A query (s, t) inside one component costs dist(s, t).
    - A query across components costs dist(s, airport(s)) + dist(airport(t), t).

    The airport of a component is the vertex that minimises the summed
    distance to all cross-component query endpoints inside it. Ties go to
    the lowest vertex index.

    Args:
        N: number of vertices, labelled 0 .. N-1.
        edges: iterable of 0-indexed (u, v) pairs, assumed to form a forest
            (the DFS traversals skip only the parent vertex).
        queries: sequence of 0-indexed (s, t) pairs; it is iterated twice.
    """
    E = [[] for _ in range(N)]
    uf = UnionFind(N)
    for u, v in edges:
        E[u].append((1, v))
        E[v].append((1, u))
        uf.unite(u, v)

    # cnt[v]: how many cross-component query endpoints sit at v.
    cnt = [0] * N
    for a, b in queries:
        if not uf.same(a, b):
            cnt[a] += 1
            cnt[b] += 1

    # DP value: (sum of weighted distances to counted endpoints, number of endpoints).
    def f(a, v, ch, cost):
        # Crossing one edge adds `cost` to every endpoint's distance.
        return (a[0] + cost * a[1], a[1])

    def g(a, v):
        return (a[0], a[1] + cnt[v])

    def merge(a, b):
        return (a[0] + b[0], a[1] + b[1])

    roots = uf.roots()
    gidx = {r: i for i, r in enumerate(roots)}
    groups = [[] for _ in range(len(roots))]
    for i in range(N):
        groups[gidx[uf.find(i)]].append(i)

    solver = Rerooting(N, E, f, g, merge, (0, 0))
    tree = LCA(N, E)
    airport = [0] * N
    for members in groups:
        first = members[0]
        solver.calculate(root=first)
        tree._dfs(first)
        best, best_cost = first, solver.solve(first)[0]
        for v in members[1:]:
            c = solver.solve(v)[0]
            if c < best_cost:
                best, best_cost = v, c
        for v in members:
            airport[v] = best

    tree.calculate()
    ans = 0
    for s, t in queries:
        sa, ta = airport[s], airport[t]
        # Airports are unique per component, so equal airports <=> same component.
        if sa == ta:
            ans += tree.dist(s, t)
        else:
            ans += tree.dist(s, sa) + tree.dist(ta, t)
    return ans


def main():
    """Read 'N M Q', M edges and Q queries (1-indexed) from stdin; print the answer."""
    N, M, Q = readints()
    edges = []
    for _ in range(M):
        u, v = readints()
        edges.append((u - 1, v - 1))
    queries = []
    for _ in range(Q):
        a, b = readints()
        queries.append((a - 1, b - 1))
    print(min_total_cost(N, edges, queries))


if __name__ == "__main__":
    main()