結果

問題 No.19 ステージの選択
ユーザー しらっ亭しらっ亭
提出日時 2015-06-23 01:37:56
言語 Python3
(3.12.2 + numpy 1.26.4 + scipy 1.12.0)
結果
WA  
実行時間 -
コード長 1,527 bytes
コンパイル時間 90 ms
コンパイル使用メモリ 10,912 KB
実行使用メモリ 8,760 KB
最終ジャッジ日時 2023-09-22 00:05:52
合計ジャッジ時間 1,952 ms
ジャッジサーバーID
(参考情報)
judge12 / judge15
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 21 ms
8,612 KB
testcase_01 AC 21 ms
8,684 KB
testcase_02 AC 20 ms
8,752 KB
testcase_03 AC 19 ms
8,680 KB
testcase_04 AC 19 ms
8,712 KB
testcase_05 AC 20 ms
8,712 KB
testcase_06 AC 19 ms
8,684 KB
testcase_07 AC 20 ms
8,572 KB
testcase_08 AC 20 ms
8,760 KB
testcase_09 AC 20 ms
8,748 KB
testcase_10 AC 20 ms
8,756 KB
testcase_11 WA -
testcase_12 WA -
testcase_13 WA -
testcase_14 AC 20 ms
8,688 KB
testcase_15 AC 20 ms
8,736 KB
testcase_16 AC 19 ms
8,748 KB
testcase_17 AC 20 ms
8,616 KB
testcase_18 AC 20 ms
8,580 KB
testcase_19 AC 19 ms
8,684 KB
testcase_20 AC 20 ms
8,736 KB
testcase_21 AC 20 ms
8,596 KB
testcase_22 AC 20 ms
8,700 KB
testcase_23 AC 19 ms
8,576 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

from collections import defaultdict


def solve(N, L, S):
    # union-find で分ける N が最大100なので多分大丈夫
    rangen = list(range(N))
    par = list(rangen)

    def find(x):
        if par[x] == x:
            return x
        else:
            par[x] = find(par[x])
            return par[x]

    def union(x, y):
        x = find(x)
        y = find(y)
        if x == y:
            return
        par[x] = y

    for i in rangen:
        union(i, S[i])

    G = defaultdict(list)
    for i in rangen:
        g = find(i)
        G[g].append(i)

    total = 0
    # グループそれぞれの中で
    for ns in G.values():
        # 閉路を検出
        n = ns[0]
        visited = set()
        while n not in visited:
            visited.add(n)
            n = S[n]
        loop_start = n

        # 閉路の中で最低難度のステージを探す
        min_stage = n
        while True:
            if L[n] < L[min_stage]:
                min_stage = n
            n = S[n]
            if n == loop_start:
                break

        # ループ中最低難度のステージ以外は 1/2 して足す
        for n in ns:
            if n == min_stage:
                total += L[n]
            else:
                total += L[n] / 2
    return total


def main():
    N = int(input())
    L = []
    S = []
    for i in range(N):
        l, s = input().split()
        L.append(int(l))
        S.append(int(s) - 1)
    print(solve(N, L, S))


if __name__ == '__main__':
    main()
0