結果
問題 |
No.3303 Heal Slimes 2
|
ユーザー |
![]() |
提出日時 | 2025-09-27 08:11:50 |
言語 | PyPy3 (7.3.15) |
結果 |
AC
|
実行時間 | 2,127 ms / 4,000 ms |
コード長 | 2,499 bytes |
コンパイル時間 | 474 ms |
コンパイル使用メモリ | 82,748 KB |
実行使用メモリ | 151,852 KB |
最終ジャッジ日時 | 2025-10-06 12:44:19 |
合計ジャッジ時間 | 39,022 ms |
ジャッジサーバーID (参考情報) |
judge3 / judge2 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
other | AC * 31 |
ソースコード
import bisect class BIT: def __init__(self, n): self.n = n self.data = [0] * (n+1) def add(self, idx, x): """idx: 0-indexed""" idx += 1 while idx <= self.n: self.data[idx] += x idx += idx & -idx def sum(self, idx): """sum of [0, idx)""" res = 0 while idx > 0: res += self.data[idx] idx -= idx & -idx return res def range_sum(self, l, r): """sum of [l, r)""" return self.sum(r) - self.sum(l) def get(self, idx): return self.range_sum(idx, idx+1) def solve(n, k, d, h): sorted_h = set(h) for v in h: if v >= d: sorted_h.add(v-d) sorted_h = sorted(list(sorted_h)) m = len(sorted_h) st = BIT(m) cnt = BIT(m) di = {v: i for i, v in enumerate(sorted_h)} di_inv = {i: v for v, i in di.items()} for v in h[:k]: st.add(di[v], v) cnt.add(di[v], 1) ans = 1 << 60 for i in range(n-k+1): l, r = 0, m while r - l > 1: piv = (l + r) // 2 x = di_inv[piv] u = bisect.bisect_right(sorted_h, x+d) if cnt.range_sum(0, piv) > cnt.range_sum(u, m): r = piv else: l = piv for l in range(l, min(m, l+2)): x = di_inv[l] y = x + d u = bisect.bisect_right(sorted_h, y) res = (cnt.range_sum(0, l) * x - st.range_sum(0, l)) + (st.range_sum(u, m) - y * cnt.range_sum(u, m)) ans = min(ans, res) # print((x, y), (l, u)) # print(cnt.range_sum(0, l), cnt.range_sum(u, m)) if i < n - k: st.add(di[h[i]], -h[i]) cnt.add(di[h[i]], -1) st.add(di[h[i+k]], h[i+k]) cnt.add(di[h[i+k]], 1) return ans def naive(n, k, d, h): ans = 1 << 60 for i in range(n-k+1): ls = h[i:i+k] for x in range(min(ls), max(ls)+1): res = 0 for v in ls: if v < x: res += x - v if x + d < v: res += v - (x + d) # if res == 1: # print(ls, x) ans = min(ans, res) return ans n, k, d = map(int, input().split()) h = list(map(int, input().split())) print(solve(n, k, d, h)) # print(naive(n, k, d, h)) exit() import random n = 5 for _ in range(200): k = random.randint(2, n) d = random.randint(0, 10) h = [random.randint(0, 10) for _ in range(n)] try: solve(n, k, d, h) except: print(n, k, d) print(*h) assert 0 break if solve(n, k, d, h) != naive(n, k, d, h): print(n, k, d) print(*h) break