結果
問題 | No.922 東北きりきざむたん |
ユーザー | terasa |
提出日時 | 2022-11-03 23:18:14 |
言語 | PyPy3 (7.3.15) |
結果 | AC |
実行時間 | 1,664 ms / 2,000 ms |
コード長 | 6,847 bytes |
コンパイル時間 | 577 ms |
コンパイル使用メモリ | 86,992 KB |
実行使用メモリ | 150,112 KB |
最終ジャッジ日時 | 2023-09-25 06:24:08 |
合計ジャッジ時間 | 26,208 ms |
ジャッジサーバーID (参考情報) | judge14 / judge11 |
(要ログイン)
テストケース
テストケース表示
入力 | 結果 | 実行時間 | 実行使用メモリ |
---|---|---|---|
testcase_00 | AC | 159 ms | 80,436 KB |
testcase_01 | AC | 162 ms | 80,408 KB |
testcase_02 | AC | 161 ms | 80,444 KB |
testcase_03 | AC | 164 ms | 80,364 KB |
testcase_04 | AC | 183 ms | 82,364 KB |
testcase_05 | AC | 174 ms | 81,664 KB |
testcase_06 | AC | 197 ms | 82,896 KB |
testcase_07 | AC | 187 ms | 82,444 KB |
testcase_08 | AC | 198 ms | 82,420 KB |
testcase_09 | AC | 862 ms | 122,652 KB |
testcase_10 | AC | 889 ms | 103,832 KB |
testcase_11 | AC | 852 ms | 115,356 KB |
testcase_12 | AC | 504 ms | 122,268 KB |
testcase_13 | AC | 434 ms | 96,244 KB |
testcase_14 | AC | 960 ms | 142,016 KB |
testcase_15 | AC | 377 ms | 123,788 KB |
testcase_16 | AC | 1,444 ms | 150,112 KB |
testcase_17 | AC | 1,443 ms | 147,456 KB |
testcase_18 | AC | 1,477 ms | 149,356 KB |
testcase_19 | AC | 1,418 ms | 147,536 KB |
testcase_20 | AC | 1,419 ms | 147,608 KB |
testcase_21 | AC | 1,642 ms | 148,296 KB |
testcase_22 | AC | 1,664 ms | 149,788 KB |
testcase_23 | AC | 1,295 ms | 144,136 KB |
testcase_24 | AC | 1,226 ms | 144,896 KB |
testcase_25 | AC | 960 ms | 143,372 KB |
testcase_26 | AC | 953 ms | 144,128 KB |
testcase_27 | AC | 927 ms | 142,444 KB |
testcase_28 | AC | 456 ms | 136,676 KB |
testcase_29 | AC | 909 ms | 145,068 KB |
ソースコード
from typing import Callable, List, Optional, Tuple, TypeVar
import sys
import itertools
import heapq
import bisect
from collections import deque, defaultdict
from functools import lru_cache, cmp_to_key

# Fast line reader; the helpers below all read through it.
input = sys.stdin.readline


def readints():
    """Return an iterator over the integers on the next stdin line."""
    return map(int, input().split())


def readlist():
    """Return the integers on the next stdin line as a list."""
    return list(readints())


def readstr():
    """Return the next stdin line with trailing whitespace stripped."""
    return input().rstrip()


T = TypeVar('T')


class Rerooting:
    """All-roots tree DP ("rerooting") over an adjacency list.

    Reference: https://null-mn.hatenablog.com/entry/2020/04/14/124151

    The value dp_v of the subtree rooted at an arbitrary vertex v must be
    expressible through its children c1, ..., ck as

        dp_v = g(merge(f(dp_c1, v, c1, w1), ..., f(dp_ck, v, ck, wk)), v)

    where w_i is the weight stored on edge (v, c_i) and ``merge`` is
    associative with identity ``e``.

    ``E[v]`` is a list of ``(weight, neighbour)`` pairs. After
    ``calculate(root)``, ``solve(v)`` returns the value of the tree rerooted
    at v for every vertex v reachable from ``root``.
    """

    def __init__(self, N: int, E: List[List[Tuple[int, int]]],
                 f: Callable[[T, int, int, int], T],
                 g: Callable[[T, int], T],
                 merge: Callable[[T, T], T], e: T):
        self.N = N
        self.E = E
        self.f = f
        self.g = g
        self.merge = merge
        self.e = e
        # ret[v]: value of v's subtree when the tree is rooted at `root`.
        self.ret = [e] * N
        # dp[v][i]: value of the subtree hanging off v through edge E[v][i].
        self.dp = [[e] * len(E[v]) for v in range(N)]

    def _dfs1(self, root):
        """Bottom-up pass from root: fill ret[] and dp[v][i] for child edges."""
        E, dp, ret = self.E, self.dp, self.ret
        f, g, merge, e = self.f, self.g, self.merge, self.e
        stack = [(root, -1)]
        while stack:
            v, p = stack.pop()
            if v < 0:
                # Post-order visit (encoded as ~v): all children are done.
                v = ~v
                acc = e
                for i, (w, d) in enumerate(E[v]):
                    if d == p:
                        continue
                    dp[v][i] = ret[d]
                    acc = merge(acc, f(ret[d], v, d, w))
                ret[v] = g(acc, v)
                continue
            stack.append((~v, p))
            for _, d in E[v]:
                if d != p:
                    stack.append((d, v))

    def _dfs2(self, root):
        """Top-down pass: fill dp[v][i] for the edge leading to v's parent."""
        E, dp = self.E, self.dp
        f, g, merge, e = self.f, self.g, self.merge, self.e
        stack = [(root, -1, e)]
        while stack:
            v, p, from_par = stack.pop()
            adj = E[v]
            dpv = dp[v]
            for i, (_, d) in enumerate(adj):
                if d == p:
                    dpv[i] = from_par
                    break
            ch = len(adj)
            # suffix[i] folds f(...) over edges i..ch-1.
            suffix = [e] * (ch + 1)
            for i in range(ch - 1, -1, -1):
                w, d = adj[i]
                suffix[i] = merge(suffix[i + 1], f(dpv[i], v, d, w))
            prefix = e
            for i, (w, d) in enumerate(adj):
                if d != p:
                    # v's value as seen from child d: every edge except i.
                    stack.append((d, v, g(merge(prefix, suffix[i + 1]), v)))
                prefix = merge(prefix, f(dpv[i], v, d, w))

    def calculate(self, root=0):
        """Run both passes for the component that contains ``root``."""
        self._dfs1(root)
        self._dfs2(root)

    def solve(self, v):
        """Return the DP value of the tree rerooted at ``v``."""
        acc = self.e
        for i, (w, d) in enumerate(self.E[v]):
            acc = self.merge(acc, self.f(self.dp[v][i], v, d, w))
        return self.g(acc, v)


class UnionFind:
    """Disjoint-set union with union by size and path compression.

    par[x] < 0 marks a root whose component size is -par[x].
    """

    def __init__(self, N):
        self.N = N
        self.par = [-1] * N

    def find(self, x):
        """Return the root of x's set (iterative, so no recursion limit)."""
        par = self.par
        root = x
        while par[root] >= 0:
            root = par[root]
        # Path compression: point every vertex on the walk at the root.
        while x != root:
            par[x], x = root, par[x]
        return root

    def unite(self, x, y):
        """Merge the sets of x and y; return False if already merged."""
        x = self.find(x)
        y = self.find(y)
        if x == y:
            return False
        if self.par[x] > self.par[y]:
            x, y = y, x
        self.par[x] += self.par[y]
        self.par[y] = x
        return True

    def same(self, x, y):
        """Return whether x and y are in the same set."""
        return self.find(x) == self.find(y)

    def size(self, x):
        """Return the size of x's set."""
        return -self.par[self.find(x)]

    def roots(self):
        """Return all set representatives in increasing order."""
        return [i for i in range(self.N) if self.par[i] < 0]


class LCA:
    """Binary-lifting ancestor queries on a forest.

    ``E[v]`` is a list of ``(weight, neighbour)`` pairs; weights are ignored
    and distances count edges. Call ``_dfs(root)`` once per tree, then
    ``calculate()`` once, before issuing queries.
    """

    def __init__(self, N, E):
        self.N = N
        self.E = E
        self.K = N.bit_length()
        # par[k][v]: the 2**k-th ancestor of v, or -1 if none.
        self.par = [[-1] * N for _ in range(self.K)]
        self.depth = [None] * N

    def _dfs(self, root):
        """Record parent and depth for every vertex of root's tree."""
        self.depth[root] = 0
        stack = [(root, -1)]
        while stack:
            v, p = stack.pop()
            if p >= 0:
                self.par[0][v] = p
                self.depth[v] = self.depth[p] + 1
            for _, dest in self.E[v]:
                if dest != p:
                    stack.append((dest, v))

    def calculate(self):
        """Fill the doubling table from the direct parents."""
        for k in range(self.K - 1):
            cur, nxt = self.par[k], self.par[k + 1]
            for i in range(self.N):
                if cur[i] >= 0:
                    nxt[i] = cur[cur[i]]

    def la(self, v, x):
        """Return the ancestor ``x`` levels above ``v``, or -1 if none exists."""
        # Any x >= 2**K exceeds every depth (2**K > N > depth), and its high
        # bits could not be walked with the table anyway.
        if x < 0 or x >> self.K:
            return -1
        for k in range(self.K):
            if x >> k & 1:
                v = self.par[k][v]
                if v < 0:
                    # Stop: par[k][-1] would wrap to the last vertex.
                    return -1
        return v

    def lca(self, u, v):
        """Return the lowest common ancestor of u and v (same tree)."""
        if self.depth[u] > self.depth[v]:
            u, v = v, u
        v = self.la(v, self.depth[v] - self.depth[u])
        if u == v:
            return u
        for k in reversed(range(self.K)):
            pu, pv = self.par[k][u], self.par[k][v]
            if pu != pv:
                u, v = pu, pv
        return self.par[0][v]

    def dist(self, u, v):
        """Return the number of edges on the path between u and v."""
        return self.depth[u] + self.depth[v] - 2 * self.depth[self.lca(u, v)]

    def jump(self, u, v, x):
        """Return the vertex x steps from u toward v, or -1 past v."""
        lca = self.lca(u, v)
        d1 = self.depth[u] - self.depth[lca]
        d2 = self.depth[v] - self.depth[lca]
        if d1 + d2 < x:
            return -1
        if x <= d1:
            return self.la(u, x)
        return self.la(v, d1 + d2 - x)


def min_total_distance(N, edges, queries):
    """Return the total walking distance over all trips (No.922).

    ``edges`` and ``queries`` are iterables of 1-indexed vertex pairs; the
    graph is assumed to be a forest. A trip inside one tree walks the tree
    path. A trip between trees walks to the airport of the start tree, flies
    for free to the airport of the goal tree, and walks to the goal. Each
    tree gets one airport, placed at the first vertex (in index order) that
    minimises the summed distance to the cross-tree trip endpoints in it.
    """
    adj = [[] for _ in range(N)]
    uf = UnionFind(N)
    for u, v in edges:
        u -= 1
        v -= 1
        adj[u].append((1, v))
        adj[v].append((1, u))
        uf.unite(u, v)

    trips = []
    cnt = [0] * N  # cross-tree trip endpoints located at each vertex
    for a, b in queries:
        a -= 1
        b -= 1
        if not uf.same(a, b):
            cnt[a] += 1
            cnt[b] += 1
        trips.append((a, b))

    # DP value is (sum of weighted distances, total endpoint weight).
    def f(a, v, ch, cost):
        # Lifting a child subtree one edge adds its weight to every distance.
        return (a[0] + a[1], a[1])

    def g(a, v):
        return (a[0], a[1] + cnt[v])

    def merge(a, b):
        return (a[0] + b[0], a[1] + b[1])

    roots = uf.roots()
    index_of_root = {r: i for i, r in enumerate(roots)}
    members = [[] for _ in roots]
    for v in range(N):
        members[index_of_root[uf.find(v)]].append(v)

    solver = Rerooting(N, adj, f, g, merge, (0, 0))
    lca = LCA(N, adj)
    airport = [0] * N
    for comp in members:
        solver.calculate(root=comp[0])
        lca._dfs(comp[0])
        # min() keeps the first minimal vertex, matching a strict '>' scan.
        best = min(comp, key=lambda v: solver.solve(v)[0])
        for v in comp:
            airport[v] = best
    lca.calculate()

    total = 0
    for s, t in trips:
        if uf.same(s, t):
            total += lca.dist(s, t)
        else:
            total += lca.dist(s, airport[s]) + lca.dist(airport[t], t)
    return total


def main():
    """Read the problem input from stdin and print the answer."""
    N, M, Q = readints()
    edges = [tuple(readints()) for _ in range(M)]
    queries = [tuple(readints()) for _ in range(Q)]
    print(min_total_distance(N, edges, queries))


if __name__ == '__main__':
    main()