結果

問題 No.3464 Max and Sum on Grid
コンテスト
ユーザー まぬお
提出日時 2026-02-28 16:29:48
言語 PyPy3 (7.3.17)
コンパイル:
pypy3 -mpy_compile _filename_
実行:
pypy3 _filename_
結果
TLE  
実行時間 -
コード長 6,591 bytes
記録
記録タグの例:
初AC ショートコード 純ショートコード 純主流ショートコード 最速実行時間
コンパイル時間 347 ms
コンパイル使用メモリ 77,660 KB
実行使用メモリ 206,196 KB
最終ジャッジ日時 2026-02-28 16:30:38
合計ジャッジ時間 7,337 ms
ジャッジサーバーID
(参考情報)
judge7 / judge4
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 1
other TLE * 1 -- * 9
権限があれば一括ダウンロードができます

ソースコード

diff #
raw source code

from collections import deque, defaultdict, Counter
from bisect import bisect_left, bisect_right
from itertools import permutations, combinations, groupby
from heapq import heappop, heappush
import math, sys
def input():
    """Read one line from sys.stdin with the trailing CR/LF stripped."""
    return sys.stdin.readline().rstrip("\r\n")


def printl(li, sep=" "):
    """Print the elements of li on a single line, converted with str and joined by sep."""
    print(sep.join(map(str, li)))


Yes, No = "Yes", "No"


def yn(flag):
    """Print Yes when flag is truthy, otherwise No."""
    print(Yes if flag else No)


def _int(x):
    """Parse x as an int and convert it from 1-based to 0-based (x - 1)."""
    return int(x) - 1


MOD = 998244353  # alternative modulus: 10**9+7
INF = 1 << 60
def ctypes(li, types):
    """Convert each raw field with its matching converter.

    Returns [types[0](li[0]), types[1](li[1]), ...]; li and types must have
    the same length.
    """
    assert len(li) == len(types)
    return [convert(raw) for convert, raw in zip(types, li)]
def tinput(*types):
    """Read one whitespace-separated line and convert field k with types[k]."""
    return ctypes(input().split(), types)
def qinput(*types):
    """Read one query line of the form "t arg1 arg2 ..." and convert its arguments.

    t is the 1-based query type and types[t-1] is the sequence of converters
    for that type's arguments. Returns (t, [converted args]); an argument-less
    query yields (t, []) when types[t-1] is empty.
    """
    li = input().split()
    t = int(li[0])
    # Every query goes through ctypes, so a line holding only the type no
    # longer indexes a missing li[1].
    return t, ctypes(li[1:], types[t - 1])

class MoSolver:
    """Offline evaluator for queries on a moving window (cur_l, cur_r) via Mo's algorithm.

    Queries are registered with add_query and answered together by solve.
    solve sorts them so the window moves little between consecutive queries,
    and drives the window with user-supplied callbacks.
    """

    n: int  # coordinate upper bound; only used to choose the block size
    q: int  # number of registered queries
    qs: list[tuple[int, int, int]]  # (l, r, registration index)

    def __init__(self, n: int):
        self.n = n
        self.q = 0
        self.qs = []

    def add_query(self, l: int, r: int):  # [l, r)
        """Register the query [l, r); answers come back in registration order."""
        self.qs.append((l, r, self.q))
        self.q += 1

    def solve(self, add_left, add_right, erase_left, erase_right, answer) -> list[int]:
        """Answer every registered query and return the answers in registration order.

        Callback contract (each receives the window bounds at call time):
          add_left(l, r)    - called after cur_l was decremented to l
          add_right(l, r)   - called before cur_r is incremented past r
          erase_left(l, r)  - called before cur_l is incremented past l
          erase_right(l, r) - called after cur_r was decremented to r
          answer()          - returns the value for the current window
        Sorts self.qs in place.
        """
        if self.q == 0:
            # Nothing to answer; also avoids dividing by zero in the block size.
            return []
        width = int(self.n / (self.q ** 0.5 * 0.8)) + 1

        def order(query):
            # Alternate the direction of r per block to shorten the r sweep.
            block = query[0] // width
            return (block, query[1] if block % 2 == 0 else -query[1])

        self.qs.sort(key=order)

        ans = [0] * self.q
        cur_l, cur_r = 0, 0
        for l, r, idx in self.qs:
            while cur_l > l:  # [l, r) -> [l-1, r)
                cur_l -= 1
                add_left(cur_l, cur_r)
            while cur_r < r:  # [l, r) -> [l, r+1)
                add_right(cur_l, cur_r)
                cur_r += 1
            while cur_l < l:  # [l, r) -> [l+1, r)
                erase_left(cur_l, cur_r)
                cur_l += 1
            while cur_r > r:  # [l, r) -> [l, r-1)
                cur_r -= 1
                erase_right(cur_l, cur_r)
            ans[idx] = answer()
        return ans
    
from bisect import bisect_right, bisect_left
from itertools import accumulate

class MergeSortTree:
    """Static segment tree whose nodes store the sorted values of their index range.

    For a half-open index range [l, r) of the original list it answers how
    many values are <=, <, >=, > x and what those values sum to. Each query
    visits O(log n) nodes with one bisect per node. Values only need to be
    mutually comparable (and summable for the sum_* queries); they need not
    be integers.
    """

    def __init__(self, v: list[int]):
        self._n = len(v)
        self._log = (self._n - 1).bit_length()
        self._size = 1 << self._log
        # _d[node]: sorted values of the node's range. Leaves start at _size;
        # node i has children 2i and 2i+1.
        self._d = [[] for _ in range(2 * self._size)]
        for i, x in enumerate(v):
            self._d[self._size + i] = [x]
        for i in range(self._size - 1, 0, -1):
            self._d[i] = sorted(self._d[2 * i] + self._d[2 * i + 1])
        # _cum[node][k]: sum of the k smallest values of the node, with a
        # leading 0. Built for every node, empty padding leaves included, so
        # _cum[node][-1] is always the node's total.
        self._cum = [list(accumulate(d, initial=0)) for d in self._d]

    def _query_le(self, left: int, right: int, x, mode="count", strict=False):
        """Aggregate the values in [left, right) that are <= x (< x when strict).

        mode "count" returns how many such values there are; any other mode
        (callers use "sum") returns their sum.
        """
        # Number of leading node values that are <= x (or < x when strict).
        find = bisect_left if strict else bisect_right
        want_count = mode == "count"
        d, cum = self._d, self._cum
        res = 0
        left += self._size
        right += self._size
        while left < right:
            if left & 1:
                idx = find(d[left], x)
                res += idx if want_count else cum[left][idx]
                left += 1
            if right & 1:
                right -= 1
                idx = find(d[right], x)
                res += idx if want_count else cum[right][idx]
            left >>= 1
            right >>= 1
        return res

    def _query_ge(self, left: int, right: int, x, mode="count", strict=False):
        """Aggregate the values in [left, right) that are >= x (> x when strict).

        mode "count" returns how many such values there are; any other mode
        (callers use "sum") returns their sum.
        """
        # First node position holding a value >= x (or > x when strict).
        find = bisect_right if strict else bisect_left
        want_count = mode == "count"
        d, cum = self._d, self._cum
        res = 0
        left += self._size
        right += self._size
        while left < right:
            if left & 1:
                idx = find(d[left], x)
                res += len(d[left]) - idx if want_count else cum[left][-1] - cum[left][idx]
                left += 1
            if right & 1:
                right -= 1
                idx = find(d[right], x)
                res += len(d[right]) - idx if want_count else cum[right][-1] - cum[right][idx]
            left >>= 1
            right >>= 1
        return res

    # Count / sum of the values v among positions [l, r) with
    # v <= x, v < x, v >= x and v > x respectively.
    def count_le(self, l, r, x): return self._query_le(l, r, x, "count")
    def sum_le(self, l, r, x):   return self._query_le(l, r, x, "sum")
    def count_lt(self, l, r, x): return self._query_le(l, r, x, "count", strict=True)
    def sum_lt(self, l, r, x):   return self._query_le(l, r, x, "sum", strict=True)
    def count_ge(self, l, r, x): return self._query_ge(l, r, x, "count")
    def sum_ge(self, l, r, x):   return self._query_ge(l, r, x, "sum")
    def count_gt(self, l, r, x): return self._query_ge(l, r, x, "count", strict=True)
    def sum_gt(self, l, r, x):   return self._query_ge(l, r, x, "sum", strict=True)

    def count_range(self, l, r, low, high):
        """Number of values v among positions [l, r) with low <= v <= high (0 if low > high)."""
        if low > high:
            return 0
        return self.count_le(l, r, high) - self.count_lt(l, r, low)

    def sum_range(self, l, r, low, high):
        """Sum of values v among positions [l, r) with low <= v <= high (0 if low > high)."""
        if low > high:
            return 0
        return self.sum_le(l, r, high) - self.sum_lt(l, r, low)

    def kth_smallest(self, l: int, r: int, k: int) -> int:
        """Return the k-th (1-indexed) smallest value among positions [l, r)."""
        assert 0 < k <= (r - l)
        values = self._d[1]  # every stored value, sorted
        lo, hi = 0, len(values) - 1
        # Find the smallest position p with count_le(values[p]) >= k. The last
        # position always qualifies because count_le(max) == r - l >= k.
        while lo < hi:
            mid = (lo + hi) // 2
            if self.count_le(l, r, values[mid]) >= k:
                hi = mid
            else:
                lo = mid + 1
        return values[lo]
    
def main():
    """Answer rectangle queries of sum(max(A[i], B[j])) and print one result per line.

    Input: "N Q", then the N values of A, then the N values of B, then Q
    queries "l d r u" (1-indexed, inclusive). Each query asks for the sum of
    max(A[i], B[j]) over l <= i <= r and d <= j <= u.

    With S(x, y) = sum_{j < x} sum_{i < y} max(A[i], B[j]) (0-indexed
    prefixes), a query equals
    S(u, r) - S(d-1, r) - S(u, l-1) + S(d-1, l-1).
    All 4Q corner values are evaluated offline by MoSolver, whose window
    (cur_l, cur_r) is the point (x, y). Moving x by one adds or removes the
    row B[x]; moving y adds or removes the column A[y]. Each change is
    computed in O(log N) with Fenwick trees over the current prefixes
    B[0:x) and A[0:y).
    """
    data = sys.stdin.buffer.read().split()
    n, q = int(data[0]), int(data[1])
    a = [int(t) for t in data[2:2 + n]]
    b = [int(t) for t in data[2 + n:2 + 2 * n]]
    if q == 0:
        return

    mo = MoSolver(n)
    pos = 2 + 2 * n
    for _ in range(q):
        l, d, r, u = (int(t) for t in data[pos:pos + 4])
        pos += 4
        l -= 1
        d -= 1
        # The four inclusion-exclusion corners, in the order combined below.
        mo.add_query(d, l)
        mo.add_query(u, l)
        mo.add_query(d, r)
        mo.add_query(u, r)

    # Compress A and B into one shared rank space so that
    # rank(y) <= rank(x) exactly when y <= x.
    values = sorted(set(a) | set(b))
    rank = {v: k for k, v in enumerate(values, 1)}
    m = len(values)
    rank_a = [rank[v] for v in a]
    rank_b = [rank[v] for v in b]

    # Fenwick trees (count and value sum, indexed by rank) holding exactly
    # A[0:cur_r) and B[0:cur_l).
    cnt_a = [0] * (m + 1)
    sum_a = [0] * (m + 1)
    cnt_b = [0] * (m + 1)
    sum_b = [0] * (m + 1)
    cur = 0      # S(cur_l, cur_r)
    total_a = 0  # sum(A[0:cur_r))
    total_b = 0  # sum(B[0:cur_l))

    def fw_add(cnt, tot, k, dc, dv):
        """Add dc to the count and dv to the value sum stored at rank k."""
        while k <= m:
            cnt[k] += dc
            tot[k] += dv
            k += k & -k

    def max_sum(cnt, tot, total, k, x):
        """Return sum(max(y, x)) over the multiset in (cnt, tot); k is rank(x)."""
        c = s = 0
        while k:
            c += cnt[k]
            s += tot[k]
            k &= k - 1
        # Values <= x contribute x each; larger values contribute themselves.
        return total - s + c * x

    def add_left(l, r):
        # cur_l just moved l+1 -> l: row B[l] leaves the B prefix.
        nonlocal cur, total_b
        x, k = b[l], rank_b[l]
        cur -= max_sum(cnt_a, sum_a, total_a, k, x)
        fw_add(cnt_b, sum_b, k, -1, -x)
        total_b -= x

    def erase_left(l, r):
        # cur_l is about to move l -> l+1: row B[l] joins the B prefix.
        nonlocal cur, total_b
        x, k = b[l], rank_b[l]
        cur += max_sum(cnt_a, sum_a, total_a, k, x)
        fw_add(cnt_b, sum_b, k, 1, x)
        total_b += x

    def add_right(l, r):
        # cur_r is about to move r -> r+1: column A[r] joins the A prefix.
        nonlocal cur, total_a
        y, k = a[r], rank_a[r]
        cur += max_sum(cnt_b, sum_b, total_b, k, y)
        fw_add(cnt_a, sum_a, k, 1, y)
        total_a += y

    def erase_right(l, r):
        # cur_r just moved r+1 -> r: column A[r] leaves the A prefix.
        nonlocal cur, total_a
        y, k = a[r], rank_a[r]
        cur -= max_sum(cnt_b, sum_b, total_b, k, y)
        fw_add(cnt_a, sum_a, k, -1, -y)
        total_a -= y

    def answer():
        return cur

    res = mo.solve(add_left, add_right, erase_left, erase_right, answer)
    out = [res[k + 3] - res[k + 2] - res[k + 1] + res[k] for k in range(0, 4 * q, 4)]
    sys.stdout.write("\n".join(map(str, out)) + "\n")


if __name__ == "__main__":
    main()