結果

問題 No.1194 Replace
ユーザー lam6er
提出日時 2025-04-15 21:39:43
言語 PyPy3
(7.3.15)
結果
WA  
実行時間 -
コード長 3,334 bytes
コンパイル時間 459 ms
コンパイル使用メモリ 82,228 KB
実行使用メモリ 216,380 KB
最終ジャッジ日時 2025-04-15 21:42:10
合計ジャッジ時間 13,036 ms
ジャッジサーバーID
(参考情報)
judge4 / judge1
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 3
other AC * 7 WA * 20
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
from collections import defaultdict, deque

def main():
    sys.setrecursionlimit(1 << 25)
    N, M = map(int, sys.stdin.readline().split())
    replace_map = {}
    for _ in range(M):
        B, C = map(int, sys.stdin.readline().split())
        if B not in replace_map or C > replace_map[B]:
            replace_map[B] = C

    if not replace_map:
        print(N * (N + 1) // 2)
        return

    nodes = list(replace_map.keys())
    adj = defaultdict(list)
    for B in nodes:
        C = replace_map[B]
        if C in replace_map:
            adj[B].append(C)

    visited = set()
    order = []
    def dfs(u):
        stack = [(u, False)]
        while stack:
            node, processed = stack.pop()
            if processed:
                order.append(node)
                continue
            if node in visited:
                continue
            visited.add(node)
            stack.append((node, True))
            for v in adj.get(node, []):
                if v not in visited:
                    stack.append((v, False))

    for node in nodes:
        if node not in visited:
            dfs(node)

    reverse_adj = defaultdict(list)
    for u in adj:
        for v in adj[u]:
            reverse_adj[v].append(u)

    visited = set()
    sccs = []
    for node in reversed(order):
        if node not in visited:
            stack = [node]
            visited.add(node)
            scc = []
            while stack:
                u = stack.pop()
                scc.append(u)
                for v in reverse_adj.get(u, []):
                    if v not in visited:
                        visited.add(v)
                        stack.append(v)
            sccs.append(scc)

    scc_internal_max = []
    for scc in sccs:
        current_max = 0
        for node in scc:
            current = max(node, replace_map[node])
            if current > current_max:
                current_max = current
        scc_internal_max.append(current_max)

    scc_id = {}
    for i, scc in enumerate(sccs):
        for node in scc:
            scc_id[node] = i

    compressed_adj = defaultdict(set)
    for u in adj:
        for v in adj[u]:
            if scc_id[u] != scc_id[v]:
                compressed_adj[scc_id[u]].add(scc_id[v])

    in_degree = defaultdict(int)
    for u in compressed_adj:
        for v in compressed_adj[u]:
            in_degree[v] += 1

    queue = deque()
    for i in range(len(sccs)):
        if in_degree.get(i, 0) == 0:
            queue.append(i)

    top_order = []
    while queue:
        u = queue.popleft()
        top_order.append(u)
        for v in compressed_adj.get(u, []):
            in_degree[v] -= 1
            if in_degree[v] == 0:
                queue.append(v)

    top_order_reversed = top_order[::-1]
    scc_max = [0] * len(sccs)
    for i in top_order_reversed:
        current_max = scc_internal_max[i]
        for neighbor in compressed_adj.get(i, []):
            if scc_max[neighbor] > current_max:
                current_max = scc_max[neighbor]
        scc_max[i] = current_max

    original_sum = N * (N + 1) // 2
    sum_B = sum(replace_map.keys())
    sum_max_B = 0
    for B in replace_map:
        sum_max_B += scc_max[scc_id[B]]

    result = original_sum - sum_B + sum_max_B
    print(result)

if __name__ == "__main__":
    main()
0