結果

問題 No.1194 Replace
ユーザー lam6er
提出日時 2025-04-15 22:19:39
言語 PyPy3 (7.3.15)
結果
WA  
実行時間 -
コード長 2,991 bytes
コンパイル時間 278 ms
コンパイル使用メモリ 81,788 KB
実行使用メモリ 348,628 KB
最終ジャッジ日時 2025-04-15 22:21:52
合計ジャッジ時間 22,789 ms
ジャッジサーバーID (参考情報) judge5 / judge1
このコードへのチャレンジ (要ログイン)
ファイルパターン 結果
sample AC * 3
other WA * 27
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
from collections import defaultdict

def main():
    """Read replacement rules from stdin and print the maximised total.

    Input format (as parsed here): a first line ``N M`` followed by ``M``
    lines ``B C``.  Each pair is treated as a directed edge ``B -> C``
    ("a value B may be turned into C").  For every distinct ``B`` the
    largest value reachable along such edges (including ``B`` itself) is
    determined, and the program prints::

        N * (N + 1) // 2 + sum(best[B] - B for each distinct B)

    i.e. the sum of 1..N after every value that has an outgoing rule is
    moved to the best value it can reach.

    NOTE(review): this assumes every ``B`` lies in ``1..N`` (the formula
    counts each distinct ``B`` as one element of the initial sequence) --
    confirm against the problem constraints.
    """
    N, M = map(int, sys.stdin.readline().split())

    # Store the edges reversed (C -> [B, ...]) so that a walk starting at a
    # value visits exactly the values that can be turned into it.
    reverse_graph = defaultdict(list)
    sources = set()
    for _ in range(M):
        B, C = map(int, sys.stdin.readline().split())
        sources.add(B)
        reverse_graph[C].append(B)

    # Sweep candidate targets from largest to smallest.  The first sweep that
    # reaches a node comes from the largest value that node can reach, so the
    # node is labelled once and never revisited.  Values that only ever
    # appear as a target (pure sinks) take part as well: leaving them out
    # silently loses the best destination of everything pointing at them.
    best = {}
    for top in sorted(sources.union(reverse_graph), reverse=True):
        if top in best:
            continue
        best[top] = top
        # Explicit stack instead of recursion: chains can be as long as M.
        pending = [top]
        while pending:
            node = pending.pop()
            for prev in reverse_graph.get(node, ()):
                if prev not in best:
                    best[prev] = top
                    pending.append(prev)

    # Only values with an outgoing rule can change; everything else keeps
    # its initial value and is already covered by the 1..N sum.
    gain = sum(best[B] - B for B in sources)
    print(N * (N + 1) // 2 + gain)

# Run the solver only when this file is executed as a script, so importing
# the module does not consume stdin or print anything.
if __name__ == "__main__":
    main()
0