結果

問題 No.704 ゴミ拾い Medium
ユーザー lam6er
提出日時 2025-04-15 21:11:47
言語 PyPy3
(7.3.15)
結果
WA  
実行時間 -
コード長 2,379 bytes
コンパイル時間 649 ms
コンパイル使用メモリ 81,800 KB
実行使用メモリ 220,820 KB
最終ジャッジ日時 2025-04-15 21:17:48
合計ジャッジ時間 15,753 ms
ジャッジサーバーID
(参考情報)
judge3 / judge2
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 4
other AC * 8 WA * 36
権限があれば一括ダウンロードができます

ソースコード

diff #

import bisect

def main():
    import sys
    input = sys.stdin.read
    data = input().split()
    idx = 0
    n = int(data[idx])
    idx += 1
    a = list(map(int, data[idx:idx+n]))
    idx += n
    x = list(map(int, data[idx:idx+n]))
    idx += n
    y = list(map(int, data[idx:idx+n]))
    idx += n

    m = [0] * n
    for k in range(n):
        m_k = bisect.bisect_right(x, a[k]) - 1
        m[k] = m_k

    class SegmentTree:
        def __init__(self, size):
            self.n = 1
            while self.n < size:
                self.n <<= 1
            self.size = size
            self.tree = [float('inf')] * (2 * self.n)

        def update(self, pos, value):
            pos += self.n
            if self.tree[pos] > value:
                self.tree[pos] = value
                while pos > 1:
                    pos >>= 1
                    new_val = min(self.tree[2 * pos], self.tree[2 * pos + 1])
                    if self.tree[pos] == new_val:
                        break
                    self.tree[pos] = new_val

        def query(self, l, r):
            res = float('inf')
            l += self.n
            r += self.n
            while l < r:
                if l % 2 == 1:
                    res = min(res, self.tree[l])
                    l += 1
                if r % 2 == 1:
                    r -= 1
                    res = min(res, self.tree[r])
                l >>= 1
                r >>= 1
            return res

    st = SegmentTree(n)
    current_min1 = float('inf')
    last_m = -1
    dp = [0] * (n + 1)  # dp[0] = 0, dp[1..n] are the results

    for k in range(n):
        j = k
        val = dp[j] + x[j] + y[j]
        st.update(j, val)

        m_k = m[k]

        start = last_m + 1
        end = m_k
        for jj in range(start, end + 1):
            current_val = dp[jj] + (y[jj] - x[jj])
            if current_val < current_min1:
                current_min1 = current_val

        last_m = m_k

        left = m_k + 1
        right = k
        if left > right:
            min2 = float('inf')
        else:
            min2 = st.query(left, right + 1)

        option1 = current_min1 + a[k] if current_min1 != float('inf') else float('inf')
        option2 = min2 - a[k] if min2 != float('inf') else float('inf')
        dp[k + 1] = min(option1, option2)

    print(dp[n])

if __name__ == "__main__":
    main()
0