結果

問題 No.922 東北きりきざむたん
コンテスト
ユーザー terasa
提出日時 2022-11-03 23:10:33
言語 PyPy3
(7.3.15)
結果
TLE  
実行時間 -
コード長 6,826 bytes
記録
コンパイル時間 168 ms
コンパイル使用メモリ 82,520 KB
実行使用メモリ 269,440 KB
最終ジャッジ日時 2024-07-18 05:18:42
合計ジャッジ時間 5,898 ms
ジャッジサーバーID
(参考情報)
judge4 / judge5
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 4
other AC * 5 TLE * 1 -- * 20
権限があれば一括ダウンロードができます

ソースコード

diff #
raw source code

from typing import List, Tuple, Callable, TypeVar
from typing import List, Tuple, Optional
import sys
import itertools
import heapq
import bisect
from collections import deque, defaultdict
from functools import lru_cache, cmp_to_key

# Shadow the builtin input() with the buffered stdin line reader; the helper
# functions below all read through this name.
input = sys.stdin.readline

# for AtCoder Easy test
# The recursion limit is raised unless the script is running as 'prog.py'.
if __file__ != 'prog.py':
    sys.setrecursionlimit(10 ** 6)


def readints():
    """Read one line from stdin and return a lazy iterator over its integers."""
    tokens = input().split()
    return map(int, tokens)


def readlist():
    """Read one line from stdin and return its integers as a list."""
    return [*readints()]


def readstr():
    """Read one line from stdin and return it with trailing whitespace removed."""
    line = input()
    return line.rstrip()


T = TypeVar('T')


class Rerooting:
    """Rerooting DP over a tree (or a forest, one component per ``calculate``).

    reference: https://null-mn.hatenablog.com/entry/2020/04/14/124151

    The value dp_v computed for the subtree rooted at a vertex v is assumed to
    be expressible through v's children c1, c2, ..., ck as

        dp_v = g(merge(f(dp_c1, v, c1, cost1), ..., f(dp_ck, v, ck, costk)), v)

    Parameters:
        N: number of vertices.
        E: adjacency lists; ``E[v]`` is a list of ``(cost, dest)`` pairs.
        f: ``f(dp_child, v, child, cost)`` lifts a child's value across the
           edge ``v -- child``.
        g: ``g(acc, v)`` finalises the merged value at vertex ``v``.
        merge: combines two lifted values.
        e: identity element of ``merge``.

    After ``calculate(root)``, ``dp[v][i]`` holds the value of the subtree
    hanging off the i-th edge of ``v`` (seen from ``v``), for every vertex
    reachable from ``root``; ``solve(v)`` then returns the answer with ``v``
    as the root.
    """

    def __init__(self, N: int, E: List[List[Tuple[int, int]]],
                 f: Callable[[T, int, int, int], T],
                 g: Callable[[T, int], T],
                 merge: Callable[[T, T], T], e: T):
        self.N = N
        self.E = E
        self.f = f
        self.g = g
        self.merge = merge
        self.e = e

        self.dp = [[self.e for _ in range(len(self.E[v]))] for v in range(self.N)]
        # Scratch buffer of subtree values for _dfs1.  It is allocated once
        # here rather than on every calculate() call: with many components an
        # O(N) allocation per call makes the total work O(N * components).
        # Reuse is safe because _dfs1 writes a vertex's slot (post-order)
        # before any ancestor reads it.
        self._ret = [self.e] * self.N

    def _dfs1(self, root):
        """Bottom-up pass: fill ``dp`` for every child edge below ``root``."""
        stack = [(root, -1)]
        ret = self._ret
        while stack:
            v, p = stack.pop()
            if v < 0:
                # Post-order visit (the vertex was pushed as ~v): all children
                # have already written their value into ret.
                v = ~v
                acc = self.e
                for i, (c, d) in enumerate(self.E[v]):
                    if d == p:
                        continue
                    self.dp[v][i] = ret[d]
                    acc = self.merge(acc, self.f(ret[d], v, d, c))
                ret[v] = self.g(acc, v)
                continue

            # Pre-order visit: schedule the post-order marker, then children.
            stack.append((~v, p))
            for i, (c, d) in enumerate(self.E[v]):
                if d == p:
                    continue
                stack.append((d, v))

    def _dfs2(self, root):
        """Top-down pass: fill the parent-edge slot of ``dp`` for each vertex."""
        stack = [(root, -1, self.e)]
        while stack:
            v, p, from_par = stack.pop()
            # Store the value coming from the parent side on the parent edge.
            for i, (c, d) in enumerate(self.E[v]):
                if d == p:
                    self.dp[v][i] = from_par
                    break
            ch = len(self.E[v])
            # Sr[i] = merge of the lifted values of edges i..ch-1 (suffix).
            Sr = [self.e] * (ch + 1)
            for i in range(ch, 0, -1):
                c, d = self.E[v][i - 1]
                Sr[i - 1] = self.merge(Sr[i], self.f(self.dp[v][i - 1], v, d, c))
            # Sl = merge of the lifted values of edges 0..i-1 (running prefix).
            Sl = self.e
            for i, (c, d) in enumerate(self.E[v]):
                if d != p:
                    # Everything around v except the edge towards child d.
                    val = self.merge(Sl, Sr[i + 1])
                    stack.append((d, v, self.g(val, v)))
                Sl = self.merge(Sl, self.f(self.dp[v][i], v, d, c))

    def calculate(self, root=0):
        """Run both passes for the component containing ``root``."""
        self._dfs1(root)
        self._dfs2(root)

    def solve(self, v):
        """Return the DP value with ``v`` as root (call ``calculate`` first)."""
        ans = self.e
        for i, (c, d) in enumerate(self.E[v]):
            ans = self.merge(ans, self.f(self.dp[v][i], v, d, c))
        return self.g(ans, v)


class UnionFind:
    """Disjoint-set forest with union by size and path compression.

    ``par[x]`` is the parent index of ``x`` when ``x`` is not a root; for a
    root it is the negated size of the set.
    """

    def __init__(self, N):
        self.N = N
        self.par = [-1 for _ in range(N)]

    def find(self, x):
        """Return the representative of ``x``'s set, compressing the path."""
        par = self.par
        # First pass: locate the root.
        root = x
        while par[root] >= 0:
            root = par[root]
        # Second pass: point every vertex on the walked path at the root.
        while par[x] >= 0:
            above = par[x]
            par[x] = root
            x = above
        return root

    def unite(self, x, y):
        """Merge the sets of ``x`` and ``y``; return False if already merged."""
        big, small = self.find(x), self.find(y)
        if big == small:
            return False

        # Sizes are stored negated, so a larger stored value is a smaller set.
        if self.par[big] > self.par[small]:
            big, small = small, big

        self.par[big] += self.par[small]
        self.par[small] = big
        return True

    def same(self, x, y):
        """Return True when ``x`` and ``y`` belong to the same set."""
        rx = self.find(x)
        ry = self.find(y)
        return rx == ry

    def size(self, x):
        """Return the number of elements in ``x``'s set."""
        root = self.find(x)
        return -self.par[root]

    def roots(self):
        """Return all set representatives in increasing index order."""
        return [idx for idx, link in enumerate(self.par) if link < 0]


class LCA:
    """Binary-lifting lowest common ancestor over a tree or forest.

    ``E[v]`` is a list of ``(cost, dest)`` pairs; only ``dest`` is used.
    Usage: call ``_dfs(root)`` once for each component root, then
    ``calculate()`` once to build the doubling table, then query.

    ``par[k][v]`` is the ancestor of ``v`` that is ``2**k`` levels up, or -1
    if there is none; ``depth[v]`` is the distance from ``v`` to its root.
    """

    def __init__(self, N, E):
        self.N = N
        self.E = E
        self.K = N.bit_length()
        self.par = [[-1 for _ in range(N)] for _ in range(self.K)]
        self.depth = [None] * N

    def _dfs(self, root):
        """Record parent and depth for every vertex reachable from ``root``."""
        self.depth[root] = 0
        stack = [(root, -1)]
        while stack:
            v, p = stack.pop()
            if not p < 0:
                self.par[0][v] = p
                self.depth[v] = self.depth[p] + 1

            for _, dest in self.E[v]:
                if dest == p:
                    continue
                stack.append((dest, v))

    def calculate(self):
        """Fill the doubling table from the direct parents set by ``_dfs``."""
        for k in range(self.K - 1):
            for i in range(self.N):
                if self.par[k][i] < 0:
                    continue
                self.par[k + 1][i] = self.par[k][self.par[k][i]]

    def la(self, v, x):
        """Return the ancestor of ``v`` that is ``x`` levels up, or -1.

        -1 is returned when no such ancestor exists (``x`` exceeds the depth
        of ``v``) and for negative ``x``.
        """
        # 2**K > N is more than any possible depth; this also rejects x < 0.
        if x >> self.K:
            return -1
        for k in range(self.K):
            if x & (1 << k):
                v = self.par[k][v]
                if v < 0:
                    # Walked past the root.  Stop here: indexing par[k][-1]
                    # would silently wrap around to the last vertex.
                    return -1
        return v

    def lca(self, u, v):
        """Return the lowest common ancestor of ``u`` and ``v``.

        Both vertices are expected to lie in the same tree.
        """
        if self.depth[u] > self.depth[v]:
            u, v = v, u
        # Lift the deeper vertex to the depth of the shallower one.
        d = self.depth[v] - self.depth[u]
        v = self.la(v, d)

        if u == v:
            return u
        # Jump both up by the largest steps that keep them distinct.
        for k in range(self.K)[::-1]:
            if self.par[k][u] != self.par[k][v]:
                u = self.par[k][u]
                v = self.par[k][v]
        return self.par[0][v]

    def dist(self, u, v):
        """Return the number of edges on the path between ``u`` and ``v``."""
        return self.depth[u] + self.depth[v] - 2 * self.depth[self.lca(u, v)]

    def jump(self, u, v, x):
        """Return the vertex ``x`` steps from ``u`` along the path to ``v``.

        Returns -1 when the path is shorter than ``x`` edges.
        """
        lca = self.lca(u, v)
        d1 = self.depth[u] - self.depth[lca]
        d2 = self.depth[v] - self.depth[lca]

        if d1 + d2 < x:
            return -1
        if x <= d1:
            return self.la(u, x)
        return self.la(v, d1 + d2 - x)


# Header line: vertex count, edge count, query count.
N, M, Q = readints()
E = [[] for _ in range(N)]  # adjacency lists of (cost, dest); every edge costs 1
uf = UnionFind(N)
D = N  # number of connected components; drops by one per successful union
for _ in range(M):
    # Endpoints arrive 1-indexed; convert to 0-indexed.
    u, v = (x - 1 for x in readints())
    E[u].append((1, v))
    E[v].append((1, u))
    if uf.unite(u, v):
        D -= 1

query = []
cnt = [0] * N  # cnt[v]: how many cross-component queries have v as an endpoint
for _ in range(Q):
    a, b = (x - 1 for x in readints())
    if not uf.same(a, b):
        cnt[a] += 1
        cnt[b] += 1
    query.append((a, b))


def f(a, v, ch, cost):
    """Lift a child's pair ``(sum, weight)`` across one edge.

    The sum grows by ``weight`` and the weight is passed through unchanged.
    ``v``, ``ch`` and ``cost`` are accepted for the callback signature but
    are not used.
    """
    dist_sum = a[0]
    weight = a[1]
    return (dist_sum + weight, weight)


def g(a, v):
    """Finalise the merged pair at vertex ``v`` by adding ``cnt[v]`` to the weight.

    The sum component is returned unchanged.
    """
    dist_sum = a[0]
    weight = a[1]
    return (dist_sum, weight + cnt[v])


def merge(a, b):
    """Combine two ``(sum, weight)`` pairs by component-wise addition."""
    dist_sum = a[0] + b[0]
    weight = a[1] + b[1]
    return (dist_sum, weight)


# Number the components 0..D-1 in the order of their union-find roots and
# bucket every vertex into its component.
gidx = {root: idx for idx, root in enumerate(uf.roots())}
V = [[] for _ in range(D)]
for vertex in range(N):
    V[gidx[uf.find(vertex)]].append(vertex)

solver = Rerooting(N, E, f, g, merge, (0, 0))
lca = LCA(N, E)
airports = [None] * D  # airports[i]: chosen vertex of component i
cost = [None] * N      # cost[v]: first component of solver.solve(v)
A = [None] * N         # A[v]: chosen vertex of v's component
for comp, members in enumerate(V):
    start = members[0]
    solver.calculate(root=start)
    lca._dfs(start)
    for v in members:
        cost[v] = solver.solve(v)[0]
    # Cheapest vertex of the component; ties keep the earliest vertex.
    best = min(members, key=cost.__getitem__)
    airports[comp] = best
    for v in members:
        A[v] = best

lca.calculate()
ans = 0
for s, t in query:
    sa, ta = A[s], A[t]
    if sa != ta:
        # Different components: each endpoint goes to its own chosen vertex.
        ans += lca.dist(s, sa) + lca.dist(ta, t)
    else:
        # Same component: direct tree distance.
        ans += lca.dist(s, t)
print(ans)
0