結果

問題 No.1194 Replace
ユーザー lam6er
提出日時 2025-04-16 15:48:10
言語 PyPy3 (7.3.15)
結果
AC  
実行時間 1,546 ms / 2,000 ms
コード長 3,565 bytes
コンパイル時間 632 ms
コンパイル使用メモリ 82,256 KB
実行使用メモリ 445,248 KB
最終ジャッジ日時 2025-04-16 15:49:27
合計ジャッジ時間 29,637 ms
ジャッジサーバーID
(参考情報)
judge1 / judge5
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 3
other AC * 27
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
from sys import stdin
from collections import defaultdict, deque

# Lift the interpreter's recursion ceiling to 2**25 frames (the default is far
# lower) so deep Python-level recursion does not raise RecursionError.
# NOTE: this only changes the interpreter's counter; it does not enlarge the
# native thread stack.
sys.setrecursionlimit(2 ** 25)

def _max_reachable(operations):
    """Map every value named in *operations* to the largest value reachable from it.

    Each operation ``(b, c)`` is treated as a directed edge ``b -> c``; a value
    always reaches itself.

    Rather than condensing strongly connected components, values are visited
    in descending order. A reverse-edge BFS from the largest still-unlabelled
    value labels every unlabelled value that can reach it with that value.
    A value that is already labelled can reach something strictly larger than
    the current start, so it is never relabelled. The search is iterative,
    so long replacement chains cannot exhaust the call stack.

    Args:
        operations: iterable of ``(b, c)`` integer pairs.

    Returns:
        dict mapping each value that appears in an operation to the maximum
        value reachable from it (always >= the value itself).
    """
    predecessors = defaultdict(list)  # c -> list of b with an edge b -> c
    values = set()
    for b, c in operations:
        predecessors[c].append(b)
        values.add(b)
        values.add(c)

    best = {}
    for top in sorted(values, reverse=True):
        if top in best:
            continue  # already reaches a larger value
        best[top] = top
        queue = deque([top])
        while queue:
            v = queue.popleft()
            # .get avoids inserting empty lists for values with no predecessors.
            for u in predecessors.get(v, ()):
                if u not in best:
                    best[u] = top
                    queue.append(u)
    return best


def main():
    """Read ``N M`` and ``M`` pairs ``B C`` from stdin and print the result.

    The printed total is ``1 + 2 + ... + N``, adjusted so that every value
    appearing in an operation contributes the largest value reachable from it
    along the ``B -> C`` edges instead of contributing itself.
    """
    tokens = sys.stdin.read().split()
    n, m = int(tokens[0]), int(tokens[1])
    flat = list(map(int, tokens[2:2 + 2 * m]))
    operations = list(zip(flat[0::2], flat[1::2]))

    best = _max_reachable(operations)
    # Base sum of 1..N, plus the gain of each value touched by an operation.
    total = n * (n + 1) // 2 + sum(best[x] - x for x in best)
    print(total)

if __name__ == "__main__":
    # Run the solver only when executed as a script, not when imported.
    main()
# NOTE(review): the bare literal below looks like residue from the page scrape;
# as an expression statement it has no effect and can likely be deleted.
0