結果

問題 No.19 ステージの選択
ユーザー しらっ亭しらっ亭
提出日時 2015-06-23 01:40:18
言語 Python3
(3.12.2 + numpy 1.26.4 + scipy 1.12.0)
結果
AC  
実行時間 28 ms / 5,000 ms
コード長 1,544 bytes
コンパイル時間 70 ms
コンパイル使用メモリ 12,672 KB
実行使用メモリ 11,008 KB
最終ジャッジ日時 2024-06-02 12:29:20
合計ジャッジ時間 1,381 ms
ジャッジサーバーID
(参考情報)
judge4 / judge3
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 26 ms
10,752 KB
testcase_01 AC 27 ms
10,752 KB
testcase_02 AC 28 ms
10,752 KB
testcase_03 AC 27 ms
10,752 KB
testcase_04 AC 26 ms
10,880 KB
testcase_05 AC 25 ms
10,752 KB
testcase_06 AC 26 ms
10,752 KB
testcase_07 AC 26 ms
11,008 KB
testcase_08 AC 26 ms
10,752 KB
testcase_09 AC 26 ms
11,008 KB
testcase_10 AC 25 ms
10,752 KB
testcase_11 AC 25 ms
10,880 KB
testcase_12 AC 25 ms
11,008 KB
testcase_13 AC 25 ms
10,880 KB
testcase_14 AC 25 ms
10,880 KB
testcase_15 AC 25 ms
10,880 KB
testcase_16 AC 26 ms
10,752 KB
testcase_17 AC 25 ms
10,880 KB
testcase_18 AC 27 ms
10,752 KB
testcase_19 AC 25 ms
11,008 KB
testcase_20 AC 26 ms
10,880 KB
testcase_21 AC 25 ms
10,752 KB
testcase_22 AC 27 ms
11,008 KB
testcase_23 AC 25 ms
10,752 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

from collections import defaultdict


def solve(N, L, S):
    # union-find で分ける N が最大100なので多分大丈夫
    rangen = list(range(N))
    par = list(rangen)

    def find(x):
        if par[x] == x:
            return x
        else:
            par[x] = find(par[x])
            return par[x]

    def union(x, y):
        x = find(x)
        y = find(y)
        if x == y:
            return
        par[x] = y

    for i in rangen:
        union(i, S[i])

    G = defaultdict(list)
    for i in rangen:
        g = find(i)
        G[g].append(i)

    total = 0
    # グループそれぞれの中で
    for ns in G.values():
        # 閉路を検出
        n = ns[0]
        visited = set()
        while n not in visited:
            visited.add(n)
            n = S[n]
        loop_start = n

        # 閉路の中で最低難度のステージを探す
        min_stage = n
        while True:
            if L[n] < L[min_stage]:
                min_stage = n
            n = S[n]
            if n == loop_start:
                break

        # ループ中最低難度のステージ以外は 1/2 して足す
        for n in ns:
            if n == min_stage:
                total += L[n]
            else:
                total += L[n] / 2
    return '{:.1f}'.format(total)


def main():
    N = int(input())
    L = []
    S = []
    for i in range(N):
        l, s = input().split()
        L.append(int(l))
        S.append(int(s) - 1)
    print(solve(N, L, S))


if __name__ == '__main__':
    main()
0