Submission result
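Field | Value |
---|---|
Problem | No.155 生放送とBGM (Live Broadcast and BGM) |
User | rpy3cpp |
Submitted at | 2015-06-03 23:09:22 |
Language | PyPy3 (7.3.15) |
Result | AC |
Execution time (max / limit) | 984 ms / 6,000 ms |
Code length | 2,114 bytes |
Compile time | 230 ms |
Compile memory | 82,180 KB |
Execution memory (max) | 277,604 KB |
Last judged at | 2024-07-06 14:02:03 |
Total judge time | 4,993 ms |
Judge server ID (for reference) | judge3 / judge2 |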
Test cases
Input | Result | Execution time | Memory |
---|---|---|---|
testcase_00 | AC | 660 ms | 263,924 KB |
testcase_01 | AC | 697 ms | 268,384 KB |
testcase_02 | AC | 984 ms | 277,604 KB |
testcase_03 | AC | 35 ms | 53,352 KB |
testcase_04 | AC | 98 ms | 85,500 KB |
testcase_05 | AC | 34 ms | 52,820 KB |
testcase_06 | AC | 117 ms | 155,636 KB |
testcase_07 | AC | 41 ms | 60,808 KB |
testcase_08 | AC | 39 ms | 58,128 KB |
testcase_09 | AC | 41 ms | 65,496 KB |
testcase_10 | AC | 39 ms | 58,176 KB |
testcase_11 | AC | 43 ms | 65,720 KB |
testcase_12 | AC | 44 ms | 68,868 KB |
testcase_13 | AC | 40 ms | 61,620 KB |
testcase_14 | AC | 584 ms | 270,160 KB |
Source code
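```python
from itertools import zip_longest
from math import factorial


def read_data():
    # First line: N (number of songs) and L (broadcast length in minutes).
    # Second line: N song lengths in "mm:ss" format.
    N, L = map(int, input().split())
    Ss = list(input().split())
    MS = []
    for s in Ss:
        mm, ss = map(int, s.split(':'))
        MS.append(mm * 60 + ss)  # song length in seconds
    return N, L * 60, MS  # broadcast length converted to seconds


def solve(N, L, MS):
    # If every song fits, all N songs start before the broadcast ends.
    if sum(MS) <= L:
        return N
    MS.sort()  # ascending
    # dp[n][t]: number of n-song subsets whose total length is exactly t (0 <= t < L).
    dp = [[0] * L]
    dp[0][0] = 1
    # lower[n] / upper[n]: bounds on the totals reachable with n songs,
    # used to restrict the DP loops to the possibly non-zero range.
    lower = [L] * (N + 1)
    upper = [0] * (N + 1)
    lower[0] = 0
    # g[n]: number of pairs (song i, n-subset S of the other songs) with
    # sum(S) < L <= sum(S) + length(i), i.e. song i is the last one to start.
    g = get_g(dp, lower, upper, MS)
    # In a uniformly random order, S (in any order) followed by i forms the prefix
    # with probability n! (N-n-1)! / N!, and then n + 1 songs start.
    return sum(gi * factorial(n+1) * factorial(N-n-1) for n, gi in enumerate(g)) / factorial(N)


def get_g(dp, lower, upper, ms):
    # Divide and conquer: dp already contains every song outside ms. Each half is
    # added to dp before recursing into the other half, so at a leaf dp contains
    # all songs except the single remaining one.
    mid = len(ms) // 2
    if mid == 0:
        # Leaf: ms == [m_i]. Count subsets with L - m_i <= total < L.
        lmsi = max(len(dp[0]) - ms[0], 0)
        return [sum(dpi[lmsi:]) for dpi in dp]
    ms0 = ms[:mid]
    ms1 = ms[mid:]
    dp0, lower0, upper0 = extend_dp(dp, lower, upper, ms0)
    dp1, lower1, upper1 = extend_dp(dp, lower, upper, ms1)
    g0 = get_g(dp0, lower0, upper0, ms1)
    g1 = get_g(dp1, lower1, upper1, ms0)
    # Add the per-size counts of both halves (the lists may differ in length).
    return [g0i + g1i for g0i, g1i in zip_longest(g0, g1, fillvalue=0)]


def extend_dp(dp_original, lower_original, upper_original, ms):
    # Return a copy of the DP with every song in ms added (0/1 knapsack over
    # subset size and total length); the inputs are left unchanged.
    dp = [dpi[:] for dpi in dp_original]
    lower = lower_original[:]
    upper = upper_original[:]
    L = len(dp[0])
    for msj in ms:
        k = len(dp) - 1
        if lower[k] + msj < L:
            dp.append([0] * L)  # a subset one song larger can still stay below L
        else:
            k -= 1  # row k + 1 cannot stay below L, so skip row k
        threshold = L - msj  # totals t >= threshold would reach L after adding msj
        # Iterate n downward so each song is used at most once per subset.
        for n in range(k, -1, -1):
            if lower[n] >= threshold:
                continue
            new_lower = lower[n] + msj
            if new_lower < lower[n+1]:
                lower[n+1] = new_lower
            new_upper = min(upper[n] + msj, L-1)
            if new_upper > upper[n+1]:
                upper[n+1] = new_upper
            dpn = dp[n]
            dpn_next = dp[n + 1]
            end = min(upper[n] + 1, threshold)
            for newt, dpnt in enumerate(dpn[lower[n]:end], new_lower):
                if dpnt:
                    dpn_next[newt] += dpnt
    return dp, lower, upper


if __name__ == '__main__':
    N, L, MS = read_data()
    print(solve(N, L, MS))
```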
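To audit the submission on small inputs, the following minimal Python sketch (not part of the original submission) computes the same quantity in two slower ways. It assumes, as the weights in `solve` imply, that the play order is a uniformly random permutation and that the answer is the expected number of songs that start before the broadcast ends. Song lengths and the limit are assumed to be positive integers in seconds. `expected_started_bruteforce` averages over all orders. `expected_started_dp` rebuilds the per-song subset-count DP from scratch for every left-out song instead of sharing it through `get_g`/`extend_dp`. Both function names are hypothetical. `Fraction` keeps the results exact so the two can be compared with `==`.

```python
from fractions import Fraction
from itertools import permutations
from math import factorial


def expected_started_bruteforce(lengths, limit):
    """Average number of songs that start before `limit`, over all play orders."""
    total = 0
    orders = 0
    for order in permutations(lengths):
        elapsed = 0
        started = 0
        for length in order:
            if elapsed >= limit:
                break  # the broadcast is over; this song and the rest never start
            started += 1
            elapsed += length
        total += started
        orders += 1
    return Fraction(total, orders)


def expected_started_dp(lengths, limit):
    """Same quantity via a leave-one-out subset-count DP rebuilt for each song."""
    n_songs = len(lengths)
    if sum(lengths) <= limit:
        return Fraction(n_songs)
    numerator = 0
    for i, m_i in enumerate(lengths):
        # ways[k][t]: number of k-subsets of the other songs with total t < limit
        ways = [[0] * limit for _ in range(n_songs)]
        ways[0][0] = 1
        for j, m_j in enumerate(lengths):
            if j == i:
                continue
            # go from large k to small so song j is added at most once per subset
            for k in range(n_songs - 2, -1, -1):
                row, nxt = ways[k], ways[k + 1]
                for t in range(max(limit - m_j, 0)):
                    if row[t]:
                        nxt[t + m_j] += row[t]
        # song i is the last to start when sum(S) < limit <= sum(S) + m_i
        first = max(limit - m_i, 0)
        for k in range(n_songs):
            count = sum(ways[k][first:])
            numerator += count * factorial(k + 1) * factorial(n_songs - k - 1)
    return Fraction(numerator, factorial(n_songs))


if __name__ == "__main__":
    print(expected_started_dp([60, 60], 90))          # 2
    print(expected_started_bruteforce([100, 10], 50))  # 3/2
```

The from-scratch version adds every other song once per left-out song, roughly O(N³·L) work. The submission's `get_g` shares the DP by splitting the song list in halves, so each song is added only about log₂ N times, roughly O(N²·L·log N) overall.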