結果

問題 No.3122 Median of Medians of Division
ユーザー Naru820
提出日時 2025-04-17 11:22:45
言語 PyPy3 (7.3.15)
結果
TLE  
実行時間 -
コード長 5,222 bytes
コンパイル時間 550 ms
コンパイル使用メモリ 82,384 KB
実行使用メモリ 270,812 KB
最終ジャッジ日時 2025-04-17 11:23:00
合計ジャッジ時間 15,002 ms
ジャッジサーバーID (参考情報) judge5 / judge2
このコードへのチャレンジ (要ログイン)
ファイルパターン 結果
sample -- * 1
other TLE * 1 -- * 39
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
import bisect

def main():
    """Read the judge input from stdin and print one answer per type-2 query.

    Input (whitespace separated): ``N Q``, then ``A_1 .. A_N``, then ``Q``
    queries of three tokens each.  A query whose first token is ``1`` is
    ``1 i x`` and assigns ``A_i = x``; any other query is read as ``2 l r``
    (1-based, inclusive range).

    For ``2 l r`` the printed value is the maximum of ``A[l..r]`` when that
    maximum occurs at least twice in the range, and otherwise the
    ``ceil(len/2)``-th smallest element of the range.

    NOTE(review): the task statement is not part of this file; the answer
    rule above is kept exactly as the submission had it and has not been
    re-derived here.  Only the data-structure work is changed:
      * bucket updates use bisect (O(B)) instead of re-sorting the bucket,
      * the k-th query binary-searches on value and counts with bisect over
        the already-sorted buckets instead of sorting the whole range,
      * answers are written once at the end instead of one print per query.
    """
    tokens = sys.stdin.read().split()
    N = int(tokens[0])
    Q = int(tokens[1])
    A = list(map(int, tokens[2:2 + N]))
    ptr = 2 + N

    class SegTreeMax:
        """Segment tree storing, per node, (max value, multiplicity of max)."""

        def __init__(self, data):
            self.n = len(data)
            size = 1
            while size < self.n:
                size <<= 1
            self.size = size
            # Padding leaves stay (0, 0).  They only feed nodes that extend
            # past index n - 1, which query_max never uses as long as the
            # queried range satisfies 0 <= l <= r < n.
            self.max_val = [0] * (2 * size)
            self.count = [0] * (2 * size)
            self.max_val[size:size + self.n] = data
            self.count[size:size + self.n] = [1] * self.n
            for node in range(size - 1, 0, -1):
                self._pull(node)

        def _pull(self, node):
            """Recompute ``node`` from its two children."""
            mv = self.max_val
            ct = self.count
            left = mv[2 * node]
            right = mv[2 * node + 1]
            if left > right:
                mv[node] = left
                ct[node] = ct[2 * node]
            elif right > left:
                mv[node] = right
                ct[node] = ct[2 * node + 1]
            else:
                mv[node] = left
                ct[node] = ct[2 * node] + ct[2 * node + 1]

        def update(self, pos, val):
            """Assign ``data[pos] = val`` and refresh every ancestor."""
            pos += self.size
            self.max_val[pos] = val
            pos >>= 1
            while pos:
                self._pull(pos)
                pos >>= 1

        def query_max(self, l, r):
            """Return ``(max, count of positions attaining it)`` over [l, r]."""
            mv = self.max_val
            ct = self.count
            res = -float('inf')
            cnt = 0
            l += self.size
            r += self.size
            while l <= r:
                if l % 2 == 1:
                    if mv[l] > res:
                        res = mv[l]
                        cnt = ct[l]
                    elif mv[l] == res:
                        cnt += ct[l]
                    l += 1
                if r % 2 == 0:
                    if mv[r] > res:
                        res = mv[r]
                        cnt = ct[r]
                    elif mv[r] == res:
                        cnt += ct[r]
                    r -= 1
                l >>= 1
                r >>= 1
            return res, cnt

    seg_max = SegTreeMax(A)

    # Sqrt decomposition for k-th smallest queries.  sorted_buckets[b] is the
    # sorted multiset of A[b*bucket_size:(b+1)*bucket_size]; A itself is the
    # unsorted copy and is kept up to date by the query loop.
    bucket_size = 450
    sorted_buckets = [sorted(A[s:s + bucket_size]) for s in range(0, N, bucket_size)]
    bisect_left = bisect.bisect_left
    bisect_right = bisect.bisect_right
    insort = bisect.insort

    def update_sqrt(pos, old_val, new_val):
        """Replace one ``old_val`` by ``new_val`` in the bucket holding ``pos``."""
        bucket = sorted_buckets[pos // bucket_size]
        # old_val is present: the bucket mirrors A before the assignment.
        del bucket[bisect_left(bucket, old_val)]
        insort(bucket, new_val)

    def query_kth(l, r, k):
        """Return the k-th smallest (0-based) value of A[l..r], inclusive."""
        first = l // bucket_size
        last = r // bucket_size
        if first == last:
            # At most bucket_size elements: sorting directly is cheapest.
            return sorted(A[l:r + 1])[k]
        # Boundary buckets are taken element-wise and sorted once; the
        # buckets strictly between them are already sorted.
        loose = A[l:(first + 1) * bucket_size] + A[last * bucket_size:r + 1]
        loose.sort()
        full = sorted_buckets[first + 1:last]

        def count_le(v):
            """Number of elements of A[l..r] that are <= v."""
            total = bisect_right(loose, v)
            for bucket in full:
                total += bisect_right(bucket, v)
            return total

        lo = min([loose[0]] + [bucket[0] for bucket in full])
        hi = max([loose[-1]] + [bucket[-1] for bucket in full])
        need = k + 1
        # Smallest v with count_le(v) >= need.  count_le only steps at
        # element values, so the result is always an element of the range.
        while lo < hi:
            mid = (lo + hi) // 2
            if count_le(mid) >= need:
                hi = mid
            else:
                lo = mid + 1
        return lo

    out = []
    for _ in range(Q):
        kind = tokens[ptr]
        if kind == '1':
            i = int(tokens[ptr + 1]) - 1
            x = int(tokens[ptr + 2])
            old_val = A[i]
            A[i] = x
            seg_max.update(i, x)
            update_sqrt(i, old_val, x)
        else:
            l = int(tokens[ptr + 1]) - 1
            r = int(tokens[ptr + 2]) - 1
            max_val, cnt = seg_max.query_max(l, r)
            if cnt >= 2:
                out.append(max_val)
            else:
                length = r - l + 1
                k = (length + 1) // 2 - 1  # 0-based index of ceil(len/2)-th smallest
                out.append(query_kth(l, r, k))
        ptr += 3

    if out:
        sys.stdout.write("\n".join(map(str, out)) + "\n")

# Script entry point.  (A stray no-op `0` expression left over from the
# page scrape has been removed.)
if __name__ == '__main__':
    main()