結果
| 問題 | 
                            No.704 ゴミ拾い Medium
                             | 
                    
| コンテスト | |
| ユーザー | 
                             lam6er
                         | 
                    
| 提出日時 | 2025-04-15 21:00:19 | 
| 言語 | PyPy3  (7.3.15)  | 
                    
| 結果 | 
                             
                                AC
                                 
                             
                            
                         | 
                    
| 実行時間 | 803 ms / 1,500 ms | 
| コード長 | 2,022 bytes | 
| コンパイル時間 | 247 ms | 
| コンパイル使用メモリ | 82,168 KB | 
| 実行使用メモリ | 251,536 KB | 
| 最終ジャッジ日時 | 2025-04-15 21:06:02 | 
| 合計ジャッジ時間 | 18,486 ms | 
| 
                            ジャッジサーバーID (参考情報)  | 
                        judge5 / judge4 | 
(要ログイン)
| ファイルパターン | 結果 | 
|---|---|
| sample | AC * 4 | 
| other | AC * 44 | 
ソースコード
import bisect
class SegmentTree:
    def __init__(self, size):
        self.n = 1
        while self.n < size:
            self.n <<= 1
        self.size = size
        self.tree = [float('inf')] * (2 * self.n)
    
    def update(self, pos, value):
        pos += self.n
        self.tree[pos] = value
        while pos > 1:
            pos >>= 1
            self.tree[pos] = min(self.tree[2*pos], self.tree[2*pos+1])
    
    def query(self, l, r):
        res = float('inf')
        l += self.n
        r += self.n
        while l <= r:
            if l % 2 == 1:
                res = min(res, self.tree[l])
                l += 1
            if r % 2 == 0:
                res = min(res, self.tree[r])
                r -= 1
            l >>= 1
            r >>= 1
        return res
def main():
    import sys
    input = sys.stdin.read().split()
    ptr = 0
    n = int(input[ptr])
    ptr +=1
    a = list(map(int, input[ptr:ptr+n]))
    ptr +=n
    x = list(map(int, input[ptr:ptr+n]))
    ptr +=n
    y = list(map(int, input[ptr:ptr+n]))
    ptr +=n
    
    st_A = SegmentTree(n)
    st_B = SegmentTree(n)
    
    dp_prev = 0  # DP[-1] =0
    
    for k in range(n):
        current_a = a[k]
        current_x = x[k]
        current_y = y[k]
        
        valA = dp_prev + current_y - current_x
        valB = dp_prev + current_y + current_x
        
        st_A.update(k, valA)
        st_B.update(k, valB)
        
        # Find m: maximum j where x[j] <= current_a, j <=k
        # search in x[0..k]
        m = bisect.bisect_right(x, current_a, 0, k+1) -1
        
        minA = st_A.query(0, m) if m >=0 else float('inf')
        minB = st_B.query(m+1, k) if (m+1) <=k else float('inf')
        
        option1 = minA + current_a if minA != float('inf') else float('inf')
        option2 = minB - current_a if minB != float('inf') else float('inf')
        dp_current = min(option1, option2)
        
        dp_prev = dp_current
    
    print(dp_prev)
    
if __name__ == '__main__':
    main()
            
            
            
        
            
lam6er