結果

問題 No.19 ステージの選択
ユーザー しらっ亭しらっ亭
提出日時 2015-06-23 01:37:56
言語 Python3
(3.13.1 + numpy 2.2.1 + scipy 1.14.1)
結果
WA  
実行時間 -
コード長 1,527 bytes
コンパイル時間 202 ms
コンパイル使用メモリ 12,672 KB
実行使用メモリ 11,136 KB
最終ジャッジ日時 2024-07-07 17:00:36
合計ジャッジ時間 1,984 ms
ジャッジサーバーID
(参考情報)
judge2 / judge1
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 33 ms
10,880 KB
testcase_01 AC 32 ms
10,880 KB
testcase_02 AC 30 ms
10,624 KB
testcase_03 AC 31 ms
10,752 KB
testcase_04 AC 31 ms
10,880 KB
testcase_05 AC 30 ms
10,752 KB
testcase_06 AC 31 ms
10,880 KB
testcase_07 AC 31 ms
10,624 KB
testcase_08 AC 31 ms
10,624 KB
testcase_09 AC 31 ms
10,880 KB
testcase_10 AC 31 ms
10,880 KB
testcase_11 WA -
testcase_12 WA -
testcase_13 WA -
testcase_14 AC 31 ms
10,880 KB
testcase_15 AC 31 ms
10,880 KB
testcase_16 AC 31 ms
10,752 KB
testcase_17 AC 31 ms
10,880 KB
testcase_18 AC 31 ms
10,880 KB
testcase_19 AC 31 ms
10,880 KB
testcase_20 AC 30 ms
10,880 KB
testcase_21 AC 31 ms
10,752 KB
testcase_22 AC 31 ms
10,624 KB
testcase_23 AC 31 ms
10,752 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

from collections import defaultdict


def solve(N, L, S):
    # union-find で分ける N が最大100なので多分大丈夫
    rangen = list(range(N))
    par = list(rangen)

    def find(x):
        if par[x] == x:
            return x
        else:
            par[x] = find(par[x])
            return par[x]

    def union(x, y):
        x = find(x)
        y = find(y)
        if x == y:
            return
        par[x] = y

    for i in rangen:
        union(i, S[i])

    G = defaultdict(list)
    for i in rangen:
        g = find(i)
        G[g].append(i)

    total = 0
    # グループそれぞれの中で
    for ns in G.values():
        # 閉路を検出
        n = ns[0]
        visited = set()
        while n not in visited:
            visited.add(n)
            n = S[n]
        loop_start = n

        # 閉路の中で最低難度のステージを探す
        min_stage = n
        while True:
            if L[n] < L[min_stage]:
                min_stage = n
            n = S[n]
            if n == loop_start:
                break

        # ループ中最低難度のステージ以外は 1/2 して足す
        for n in ns:
            if n == min_stage:
                total += L[n]
            else:
                total += L[n] / 2
    return total


def main():
    N = int(input())
    L = []
    S = []
    for i in range(N):
        l, s = input().split()
        L.append(int(l))
        S.append(int(s) - 1)
    print(solve(N, L, S))


if __name__ == '__main__':
    main()
0