結果

問題 No.3464 Max and Sum on Grid
コンテスト
ユーザー tyawanmusi
提出日時 2026-02-27 09:41:10
言語 PyPy3
(7.3.17)
コンパイル:
pypy3 -mpy_compile _filename_
実行:
pypy3 _filename_
結果
AC  
実行時間 4,389 ms / 5,000 ms
コード長 3,802 bytes
記録
記録タグの例:
初AC ショートコード 純ショートコード 純主流ショートコード 最速実行時間
コンパイル時間 346 ms
コンパイル使用メモリ 78,220 KB
実行使用メモリ 181,760 KB
最終ジャッジ日時 2026-02-28 13:12:38
合計ジャッジ時間 7,529 ms
ジャッジサーバーID
(参考情報)
judge7 / judge4
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 1
other AC * 10
権限があれば一括ダウンロードができます

ソースコード

diff #
raw source code

# query: sum[l<=i<=r][d<=j<=u] max(a[i],b[j])
# f(x,y): sum[1<=i<=x][1<=j<=y] max(a[i],b[j])
# query: f(r,u) - f(l-1,u) - f(r,d-1) + f(l-1,d-1)
# f(x+1,y) = f(x,y) + sum[1<=j<=y] max(a[x+1],b[j])
# sum[1<=j<=y] max(a[x+1],b[j]) は bIT で計算できる
# b[1<=j<=y] のうち a[x+1] より小さいものと大きいものの個数と総和を bIT で管理すればよい、 a[x+1]*(個数) + (それより大きいものの総和)
# f(x,y) から f(x+1,y) を計算するのに O(log n) なので、 Mo で O(n sqrt q log n) で解ける
# 2e4 * sqrt(1e5) * log(2e4) はだいたい 9e7 ← マジ?笑
# ↓かなり似ている
# https://atcoder.jp/contests/abc384/tasks/abc384_g

import sys

def solve():
    input_data = sys.stdin.read().split()
    n = int(input_data[0])
    q = int(input_data[1])
    a = list(map(int, input_data[2:n+2]))
    b = list(map(int, input_data[n+2:2*n+2]))
    ptr = 2 * n + 2
    queries = []
    for i in range(q):
        l = int(input_data[ptr]) - 1
        d = int(input_data[ptr+1]) - 1
        r = int(input_data[ptr+2])
        u = int(input_data[ptr+3])
        ptr += 4
        queries.append((l, d, i * 4))
        queries.append((l, u, i * 4 + 1))
        queries.append((r, d, i * 4 + 2))
        queries.append((r, u, i * 4 + 3))
    total_q = q * 4
    bsize = max(1, int(n / (total_q ** 0.5)))
    queries.sort(key=lambda x: (
        x[0] // bsize,
        x[1] if (x[0] // bsize) % 2 == 0 else -x[1]
    ))
    maxe = 100005
    cnt_a = [0] * maxe
    sum_a = [0] * maxe
    cnt_b = [0] * maxe
    sum_b = [0] * maxe
    total_sum_a = 0
    total_sum_b = 0
    ans = 0
    res = [0] * total_q
    u, v = 0, 0
    for tu, tv, qi in queries:
        while v < tv:
            idx = b[v] + 1
            c = 0
            s = 0
            while idx > 0:
                c += cnt_a[idx]
                s += sum_a[idx]
                idx -= idx & (-idx)
            ans += b[v] * c + (total_sum_a - s)
            total_sum_b += b[v]
            idx = b[v] + 1
            while idx < maxe:
                cnt_b[idx] += 1
                sum_b[idx] += b[v]
                idx += idx & (-idx)
            v += 1
        while u > tu:
            u -= 1
            total_sum_a -= a[u]
            idx = a[u] + 1
            while idx < maxe:
                cnt_a[idx] -= 1
                sum_a[idx] -= a[u]
                idx += idx & (-idx)
            idx = a[u] + 1
            c = 0
            s = 0
            while idx > 0:
                c += cnt_b[idx]
                s += sum_b[idx]
                idx -= idx & (-idx)
            ans -= a[u] * c + (total_sum_b - s)
        while v > tv:
            v -= 1
            total_sum_b -= b[v]
            idx = b[v] + 1
            while idx < maxe:
                cnt_b[idx] -= 1
                sum_b[idx] -= b[v]
                idx += idx & (-idx)
            idx = b[v] + 1
            c = 0
            s = 0
            while idx > 0:
                c += cnt_a[idx]
                s += sum_a[idx]
                idx -= idx & (-idx)
            ans -= b[v] * c + (total_sum_a - s)
        while u < tu:
            idx = a[u] + 1
            c = 0
            s = 0
            while idx > 0:
                c += cnt_b[idx]
                s += sum_b[idx]
                idx -= idx & (-idx)
            ans += a[u] * c + (total_sum_b - s)
            total_sum_a += a[u]
            idx = a[u] + 1
            while idx < maxe:
                cnt_a[idx] += 1
                sum_a[idx] += a[u]
                idx += idx & (-idx)
            u += 1
        res[qi] = ans
    out = []
    for i in range(q):
        out.append(str(res[i*4 + 3] - res[i*4 + 2] - res[i*4 + 1] + res[i*4]))
    sys.stdout.write('\n'.join(out) + '\n')

if __name__ == '__main__':
    solve()
0