結果

問題 No.3464 Max and Sum on Grid
コンテスト
ユーザー tyawanmusi
提出日時 2026-02-27 09:12:42
言語 PyPy3
(7.3.17)
コンパイル:
pypy3 -mpy_compile _filename_
実行:
pypy3 _filename_
結果
TLE  
実行時間 -
コード長 4,012 bytes
記録
記録タグの例:
初AC ショートコード 純ショートコード 純主流ショートコード 最速実行時間
コンパイル時間 3,966 ms
コンパイル使用メモリ 77,908 KB
実行使用メモリ 56,588 KB
最終ジャッジ日時 2026-02-28 13:11:48
合計ジャッジ時間 7,907 ms
ジャッジサーバーID
(参考情報)
judge4 / judge7
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 1
other TLE * 1 -- * 9
権限があれば一括ダウンロードができます

ソースコード

diff #
raw source code

# query: sum[l<=i<=r][d<=j<=u] max(a[i],b[j])
# f(x,y): sum[1<=i<=x][1<=j<=y] max(a[i],b[j])
# query: f(r,u) - f(l-1,u) - f(r,d-1) + f(l-1,d-1)
# f(x+1,y) = f(x,y) + sum[1<=j<=y] max(a[x+1],b[j])
# sum[1<=j<=y] max(a[x+1],b[j]) は BIT で計算できる
# b[1<=j<=y] のうち a[x+1] より小さいものと大きいものの個数と総和を BIT で管理すればよい、 a[x+1]*(個数) + (それより大きいものの総和)
# f(x,y) から f(x+1,y) を計算するのに O(log n) なので、 Mo で O(N sqrt Q log N) で解ける
# 2e4 * sqrt(1e5) * log(2e4) はだいたい 9e7 ← マジ?笑
# ↓かなり似ている
# https://atcoder.jp/contests/abc384/tasks/abc384_g

# BinaryIndexedTree
# add(i, x): a[i] += x
# sum(i, j): a[i:j) の和
class BinaryIndexedTree:
  def __init__(self, n):
    self.bit = [0] * n

  def add(self, i, x):
    i += 1
    while i <= len(self.bit):
      self.bit[i-1] += x
      i += i & -i

  def sum_sub(self, i):
    a = 0
    while i:
      a += self.bit[i-1]
      i -= i & -i
    return a

  def sum(self, i, j):
    return self.sum_sub(j) - self.sum_sub(i)

# Mo's algorithm
# 1<=u,v<=N のクエリ (u,v) が Q 個あり、f(u,v) から f(u+1,v), f(u-1,v), f(u,v+1), f(u,v-1) を十分高速に計算できるとき、全クエリの答えを O(N sqrt Q) で求めるアルゴリズム
# u_plus では f(u,v) の状態から f(u+1,v) の状態へ遷移する。変数の更新や ans の更新を行う。
# mo = MO(N)
# for u, v in queries:
#     mo.add_query(u, v)
# for ans in mo.solve(): 
#     print(ans)
class Mo:
    def __init__(self, n):
        self.n = n
        self.queries = []
        # 初期状態をよしなに
        self.ans = 0

    def u_plus(self, u, v):
        self.ans += a[u] * cntb.sum(0, a[u] + 1) + sumb.sum(a[u] + 1, maxe)
        cnta.add(a[u], 1)
        suma.add(a[u], a[u])
        pass

    def u_minus(self, u, v):
        # f(u, v) -> f(u-1, v)
        self.ans -= a[u-1] * cntb.sum(0, a[u-1] + 1) + sumb.sum(a[u-1] + 1, maxe)
        cnta.add(a[u-1], -1)
        suma.add(a[u-1], -a[u-1])
        pass

    def v_plus(self, u, v):
        # f(u, v) -> f(u, v+1)
        self.ans += b[v] * cnta.sum(0, b[v] + 1) + suma.sum(b[v] + 1, maxe)
        cntb.add(b[v], 1)
        sumb.add(b[v], b[v])
        pass

    def v_minus(self, u, v):
        # f(u, v) -> f(u, v-1)
        self.ans -= b[v-1] * cnta.sum(0, b[v-1] + 1) + suma.sum(b[v-1] + 1, maxe)
        cntb.add(b[v-1], -1)
        sumb.add(b[v-1], -b[v-1])
        pass

    def get_ans(self):
        return self.ans

    def add_query(self, u, v):
        self.queries.append((u, v, len(self.queries)))

    def solve(self):
        q = len(self.queries)
        if q == 0:
            return []
        bsize = max(1, int(self.n / q**0.5))
        self.queries.sort(key=lambda x: (
            x[0] // bsize,
            x[1] if (x[0] // bsize) % 2 == 0 else -x[1]
        ))
        ans = [0] * q
        u, v = 0, 0
        for target_u, target_v, q_i in self.queries:
            while v < target_v:
                self.v_plus(u, v)
                v += 1
            while u > target_u:
                self.u_minus(u, v)
                u -= 1
            while v > target_v:
                self.v_minus(u, v)
                v -= 1
            while u < target_u:
                self.u_plus(u, v)
                u += 1
            ans[q_i] = self.get_ans()
        return ans

import sys
input = sys.stdin.readline
maxe = 10**5 + 1
n, q = map(int, input().split())
a = list(map(int, input().split()))
b = list(map(int, input().split()))
cnta = BinaryIndexedTree(maxe)
suma = BinaryIndexedTree(maxe)
cntb = BinaryIndexedTree(maxe)
sumb = BinaryIndexedTree(maxe)
mo = Mo(n)
for _ in range(q):
    l, d, r, u = map(int, input().split())
    mo.add_query(l-1, d-1)
    mo.add_query(l-1, u)
    mo.add_query(r, d-1)
    mo.add_query(r, u)
res = mo.solve()
for i in range(0, len(res), 4):
    ans = res[i+3] - res[i+2] - res[i+1] + res[i]
    print(ans)
0