結果
| 問題 | No.3464 Max and Sum on Grid |
| コンテスト | |
| ユーザー |
|
| 提出日時 | 2026-02-27 09:41:10 |
| 言語 | PyPy3 (7.3.17) |
| 結果 |
AC
|
| 実行時間 | 4,389 ms / 5,000 ms |
| コード長 | 3,802 bytes |
| 記録 | |
| コンパイル時間 | 346 ms |
| コンパイル使用メモリ | 78,220 KB |
| 実行使用メモリ | 181,760 KB |
| 最終ジャッジ日時 | 2026-02-28 13:12:38 |
| 合計ジャッジ時間 | 7,529 ms |
|
ジャッジサーバーID (参考情報) |
judge7 / judge4 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 1 |
| other | AC * 10 |
ソースコード
# query: sum[l<=i<=r][d<=j<=u] max(a[i],b[j])
# f(x,y): sum[1<=i<=x][1<=j<=y] max(a[i],b[j])
# query: f(r,u) - f(l-1,u) - f(r,d-1) + f(l-1,d-1)
# f(x+1,y) = f(x,y) + sum[1<=j<=y] max(a[x+1],b[j])
# sum[1<=j<=y] max(a[x+1],b[j]) は bIT で計算できる
# b[1<=j<=y] のうち a[x+1] より小さいものと大きいものの個数と総和を bIT で管理すればよい、 a[x+1]*(個数) + (それより大きいものの総和)
# f(x,y) から f(x+1,y) を計算するのに O(log n) なので、 Mo で O(n sqrt q log n) で解ける
# 2e4 * sqrt(1e5) * log(2e4) はだいたい 9e7 ← マジ?笑
# ↓かなり似ている
# https://atcoder.jp/contests/abc384/tasks/abc384_g
import sys
def solve():
input_data = sys.stdin.read().split()
n = int(input_data[0])
q = int(input_data[1])
a = list(map(int, input_data[2:n+2]))
b = list(map(int, input_data[n+2:2*n+2]))
ptr = 2 * n + 2
queries = []
for i in range(q):
l = int(input_data[ptr]) - 1
d = int(input_data[ptr+1]) - 1
r = int(input_data[ptr+2])
u = int(input_data[ptr+3])
ptr += 4
queries.append((l, d, i * 4))
queries.append((l, u, i * 4 + 1))
queries.append((r, d, i * 4 + 2))
queries.append((r, u, i * 4 + 3))
total_q = q * 4
bsize = max(1, int(n / (total_q ** 0.5)))
queries.sort(key=lambda x: (
x[0] // bsize,
x[1] if (x[0] // bsize) % 2 == 0 else -x[1]
))
maxe = 100005
cnt_a = [0] * maxe
sum_a = [0] * maxe
cnt_b = [0] * maxe
sum_b = [0] * maxe
total_sum_a = 0
total_sum_b = 0
ans = 0
res = [0] * total_q
u, v = 0, 0
for tu, tv, qi in queries:
while v < tv:
idx = b[v] + 1
c = 0
s = 0
while idx > 0:
c += cnt_a[idx]
s += sum_a[idx]
idx -= idx & (-idx)
ans += b[v] * c + (total_sum_a - s)
total_sum_b += b[v]
idx = b[v] + 1
while idx < maxe:
cnt_b[idx] += 1
sum_b[idx] += b[v]
idx += idx & (-idx)
v += 1
while u > tu:
u -= 1
total_sum_a -= a[u]
idx = a[u] + 1
while idx < maxe:
cnt_a[idx] -= 1
sum_a[idx] -= a[u]
idx += idx & (-idx)
idx = a[u] + 1
c = 0
s = 0
while idx > 0:
c += cnt_b[idx]
s += sum_b[idx]
idx -= idx & (-idx)
ans -= a[u] * c + (total_sum_b - s)
while v > tv:
v -= 1
total_sum_b -= b[v]
idx = b[v] + 1
while idx < maxe:
cnt_b[idx] -= 1
sum_b[idx] -= b[v]
idx += idx & (-idx)
idx = b[v] + 1
c = 0
s = 0
while idx > 0:
c += cnt_a[idx]
s += sum_a[idx]
idx -= idx & (-idx)
ans -= b[v] * c + (total_sum_a - s)
while u < tu:
idx = a[u] + 1
c = 0
s = 0
while idx > 0:
c += cnt_b[idx]
s += sum_b[idx]
idx -= idx & (-idx)
ans += a[u] * c + (total_sum_b - s)
total_sum_a += a[u]
idx = a[u] + 1
while idx < maxe:
cnt_a[idx] += 1
sum_a[idx] += a[u]
idx += idx & (-idx)
u += 1
res[qi] = ans
out = []
for i in range(q):
out.append(str(res[i*4 + 3] - res[i*4 + 2] - res[i*4 + 1] + res[i*4]))
sys.stdout.write('\n'.join(out) + '\n')
if __name__ == '__main__':
solve()