結果
| 問題 | 
                            No.1194 Replace
                             | 
                    
| コンテスト | |
| ユーザー | 
                             lam6er
                         | 
                    
| 提出日時 | 2025-04-16 15:36:08 | 
| 言語 | PyPy3  (7.3.15)  | 
                    
| 結果 | 
                             
                                WA
                                 
                             
                            
                         | 
                    
| 実行時間 | - | 
| コード長 | 3,334 bytes | 
| コンパイル時間 | 135 ms | 
| コンパイル使用メモリ | 82,460 KB | 
| 実行使用メモリ | 216,060 KB | 
| 最終ジャッジ日時 | 2025-04-16 15:41:16 | 
| 合計ジャッジ時間 | 12,033 ms | 
| 
                            ジャッジサーバーID (参考情報)  | 
                        judge1 / judge3 | 
(要ログイン)
| ファイルパターン | 結果 | 
|---|---|
| sample | AC * 3 | 
| other | AC * 7 WA * 20 | 
ソースコード
import sys
from collections import defaultdict, deque
def main():
    sys.setrecursionlimit(1 << 25)
    N, M = map(int, sys.stdin.readline().split())
    replace_map = {}
    for _ in range(M):
        B, C = map(int, sys.stdin.readline().split())
        if B not in replace_map or C > replace_map[B]:
            replace_map[B] = C
    if not replace_map:
        print(N * (N + 1) // 2)
        return
    nodes = list(replace_map.keys())
    adj = defaultdict(list)
    for B in nodes:
        C = replace_map[B]
        if C in replace_map:
            adj[B].append(C)
    visited = set()
    order = []
    def dfs(u):
        stack = [(u, False)]
        while stack:
            node, processed = stack.pop()
            if processed:
                order.append(node)
                continue
            if node in visited:
                continue
            visited.add(node)
            stack.append((node, True))
            for v in adj.get(node, []):
                if v not in visited:
                    stack.append((v, False))
    for node in nodes:
        if node not in visited:
            dfs(node)
    reverse_adj = defaultdict(list)
    for u in adj:
        for v in adj[u]:
            reverse_adj[v].append(u)
    visited = set()
    sccs = []
    for node in reversed(order):
        if node not in visited:
            stack = [node]
            visited.add(node)
            scc = []
            while stack:
                u = stack.pop()
                scc.append(u)
                for v in reverse_adj.get(u, []):
                    if v not in visited:
                        visited.add(v)
                        stack.append(v)
            sccs.append(scc)
    scc_internal_max = []
    for scc in sccs:
        current_max = 0
        for node in scc:
            current = max(node, replace_map[node])
            if current > current_max:
                current_max = current
        scc_internal_max.append(current_max)
    scc_id = {}
    for i, scc in enumerate(sccs):
        for node in scc:
            scc_id[node] = i
    compressed_adj = defaultdict(set)
    for u in adj:
        for v in adj[u]:
            if scc_id[u] != scc_id[v]:
                compressed_adj[scc_id[u]].add(scc_id[v])
    in_degree = defaultdict(int)
    for u in compressed_adj:
        for v in compressed_adj[u]:
            in_degree[v] += 1
    queue = deque()
    for i in range(len(sccs)):
        if in_degree.get(i, 0) == 0:
            queue.append(i)
    top_order = []
    while queue:
        u = queue.popleft()
        top_order.append(u)
        for v in compressed_adj.get(u, []):
            in_degree[v] -= 1
            if in_degree[v] == 0:
                queue.append(v)
    top_order_reversed = top_order[::-1]
    scc_max = [0] * len(sccs)
    for i in top_order_reversed:
        current_max = scc_internal_max[i]
        for neighbor in compressed_adj.get(i, []):
            if scc_max[neighbor] > current_max:
                current_max = scc_max[neighbor]
        scc_max[i] = current_max
    original_sum = N * (N + 1) // 2
    sum_B = sum(replace_map.keys())
    sum_max_B = 0
    for B in replace_map:
        sum_max_B += scc_max[scc_id[B]]
    result = original_sum - sum_B + sum_max_B
    print(result)
if __name__ == "__main__":
    main()
            
            
            
        
            
lam6er