結果

問題 No.3464 Max and Sum on Grid
コンテスト
ユーザー LyricalMaestro
提出日時 2026-03-01 23:25:46
言語 PyPy3
(7.3.17)
コンパイル:
pypy3 -mpy_compile _filename_
実行:
pypy3 _filename_
結果
TLE  
実行時間 -
コード長 5,186 bytes
記録
記録タグの例:
初AC ショートコード 純ショートコード 純主流ショートコード 最速実行時間
コンパイル時間 594 ms
コンパイル使用メモリ 77,688 KB
実行使用メモリ 198,568 KB
最終ジャッジ日時 2026-03-01 23:26:49
合計ジャッジ時間 63,205 ms
ジャッジサーバーID
(参考情報)
judge3 / judge1
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 1
other TLE * 10
権限があれば一括ダウンロードができます

ソースコード

diff #
raw source code

# https://yukicoder.me/problems/no/3464

import math

class Calculator:
    """Answer prefix sums F(x, y) = sum of max(A[i], B[j]) for i <= x, j <= y.

    Indices are grouped into ``sqrt_n`` full blocks of ``sqrt_n`` indices
    plus a tail block (number ``sqrt_n``) holding every index
    >= ``sqrt_n * sqrt_n``. The precomputed tables are expected to hold:

    * ``cum_rect_values[R][C]``: sum over rows in full blocks < R and
      columns in full blocks < C;
    * ``row_values[r][C]``: sum over row r and all columns in full blocks < C;
    * ``col_values[c][R]``: sum over column c and all rows in full blocks < R.

    The remaining partial-row x partial-column corner is computed directly.
    """

    def __init__(
        self,
        A,
        B,
        sqrt_n,
        cum_rect_values,
        row_values,
        col_values
    ):
        self.A = A
        self.B = B
        self.sqrt_n = sqrt_n
        self.cum_rect_values = cum_rect_values
        self.row_values = row_values
        self.col_values = col_values
        # For every block (tail block included) its indices ordered by value
        # descending. They are built once so that each corner term needs only
        # a linear merge instead of a fresh sort per call.
        self._row_blocks = self._sorted_blocks(A, sqrt_n)
        self._col_blocks = self._sorted_blocks(B, sqrt_n)

    @staticmethod
    def _sorted_blocks(values, sqrt_n):
        """Return, for blocks 0..sqrt_n, each block's indices sorted by value desc."""
        n = len(values)
        blocks = []
        for k in range(sqrt_n + 1):
            lo = k * sqrt_n
            # The last block absorbs every index past the full blocks.
            hi = n if k == sqrt_n else min(lo + sqrt_n, n)
            indices = list(range(lo, hi))
            indices.sort(key=values.__getitem__, reverse=True)
            blocks.append(indices)
        return blocks

    def calculate(self, row_index, col_index):
        """Return the sum of max(A[i], B[j]) over 0<=i<=row_index, 0<=j<=col_index.

        An index of -1 denotes an empty prefix and yields 0. This lets callers
        apply inclusion-exclusion with ``l - 1`` / ``d - 1`` directly.
        """
        if row_index == -1 or col_index == -1:
            return 0

        s = self.sqrt_n
        # Block holding the last row/column; anything beyond the full blocks
        # falls into the tail block number ``s``.
        r_block = min(row_index // s, s)
        c_block = min(col_index // s, s)
        row_start = r_block * s
        col_start = c_block * s

        # Full row blocks x full column blocks.
        answer = self.cum_rect_values[r_block][c_block]
        # Partial rows x full column blocks.
        row_values = self.row_values
        for r in range(row_start, row_index + 1):
            answer += row_values[r][c_block]
        # Partial columns x full row blocks.
        col_values = self.col_values
        for c in range(col_start, col_index + 1):
            answer += col_values[c][r_block]

        # Partial rows x partial columns. Both lists come out sorted
        # descending, so the sum of pairwise maxima is one two-pointer merge.
        A = self.A
        B = self.B
        rows = [A[r] for r in self._row_blocks[r_block] if r <= row_index]
        cols = [B[c] for c in self._col_blocks[c_block] if c <= col_index]
        n_rows = len(rows)
        n_cols = len(cols)
        i = j = 0
        while i < n_rows and j < n_cols:
            if rows[i] >= cols[j]:
                # rows[i] is the max against every column not yet consumed.
                answer += rows[i] * (n_cols - j)
                i += 1
            else:
                # cols[j] is the max against every row not yet consumed.
                answer += cols[j] * (n_rows - i)
                j += 1
        return answer
        
def prepare_cum_rect_values(sqrt_n, array):
    """Build the 2-D prefix sum of block-pair sums of max(A[i], B[j]).

    Args:
        sqrt_n: block size, and also the number of full blocks per axis.
            Only indices < sqrt_n * sqrt_n (the full blocks) are counted.
        array: ``(value, index, data_type)`` triples sorted by value in
            descending order. data_type 0 marks an A (row) entry; any other
            value marks a B (column) entry.

    Returns:
        A ``(sqrt_n + 1) x (sqrt_n + 1)`` table ``cum`` where ``cum[R][C]``
        is the sum of max(A[i], B[j]) over rows in blocks < R and columns in
        blocks < C.
    """
    full = sqrt_n * sqrt_n
    rect_values = [[0] * sqrt_n for _ in range(sqrt_n)]
    # Entries of each row/column block not yet visited. Values arrive in
    # descending order, so the current value is the max against every
    # unvisited entry on the other axis. One counter per block suffices.
    rows_left = [sqrt_n] * sqrt_n
    cols_left = [sqrt_n] * sqrt_n
    for value, index, data_type in array:
        if index >= full:
            # Not part of any full block.
            continue
        block = index // sqrt_n
        if data_type == 0:
            # Row side: pair this row value with every column block.
            rect_row = rect_values[block]
            for c in range(sqrt_n):
                rect_row[c] += value * cols_left[c]
            rows_left[block] -= 1
        else:
            # Column side: pair this column value with every row block.
            for r in range(sqrt_n):
                rect_values[r][block] += value * rows_left[r]
            cols_left[block] -= 1

    # 2-D prefix sum with a leading zero row/column.
    cum_rect_values = [[0] * (sqrt_n + 1) for _ in range(sqrt_n + 1)]
    for i in range(sqrt_n):
        running = 0
        for j in range(sqrt_n):
            running += rect_values[i][j]
            cum_rect_values[i + 1][j + 1] = running + cum_rect_values[i][j + 1]

    return cum_rect_values

def prepare_row_col_values(N, sqrt_n, array):
    """Build per-row and per-column prefix sums over the full blocks.

    Args:
        N: number of rows (equal to the number of columns).
        sqrt_n: block size. Only indices < sqrt_n * sqrt_n belong to a full
            block.
        array: ``(value, index, data_type)`` triples sorted by value in
            descending order. data_type 0 marks an A (row) entry; anything
            else marks a B (column) entry.

    Returns:
        ``(row_values, col_values)``, each of shape ``N x (sqrt_n + 1)``.
        ``row_values[r][k]`` is the sum of max(A[r], B[c]) over every column
        c in full blocks < k. ``col_values[c][k]`` is the sum of
        max(A[r], B[c]) over every row r in full blocks < k.
    """
    width = sqrt_n + 1
    limit = sqrt_n * sqrt_n
    tables = (
        [[0] * width for _ in range(N)],  # side 0: rows (A entries)
        [[0] * width for _ in range(N)],  # side 1: columns (B entries)
    )
    # Per side and block: sum of the values already visited and the number
    # of entries not visited yet.
    visited_sum = ([0] * sqrt_n, [0] * sqrt_n)
    unvisited = ([sqrt_n] * sqrt_n, [sqrt_n] * sqrt_n)

    for value, index, data_type in array:
        side = 0 if data_type == 0 else 1
        other = 1 - side
        other_sum = visited_sum[other]
        other_left = unvisited[other]
        target = tables[side][index]
        for block in range(sqrt_n):
            # Visited entries on the other side are >= value and contribute
            # themselves. Unvisited ones are <= value and contribute value.
            target[block + 1] += other_sum[block] + value * other_left[block]
        if index < limit:
            block = index // sqrt_n
            visited_sum[side][block] += value
            unvisited[side][block] -= 1

    # Convert per-block sums into running prefix sums along every line.
    for table in tables:
        for line in table:
            running = 0
            for k, cell in enumerate(line):
                running += cell
                line[k] = running

    return tables[0], tables[1]

def main():
    """Read N, Q, arrays A and B and Q rectangle queries; print each answer.

    Each query ``l d r u`` (1-indexed, inclusive) is answered as the sum of
    max(A[i], B[j]) over rows l..r and columns d..u. The sum is formed by
    inclusion-exclusion on Calculator prefix sums.
    """
    import sys

    # Read every token at once; per-line input() is costly in PyPy for large Q.
    tokens = sys.stdin.buffer.read().split()
    N = int(tokens[0])
    Q = int(tokens[1])
    A = list(map(int, tokens[2:2 + N]))
    B = list(map(int, tokens[2 + N:2 + 2 * N]))
    pos = 2 + 2 * N
    queries = []
    for _ in range(Q):
        l, d, r, u = map(int, tokens[pos:pos + 4])
        pos += 4
        queries.append((l - 1, d - 1, r - 1, u - 1))

    # Exact integer square root (no float rounding as with math.sqrt).
    sqrt_n = math.isqrt(N)

    # Every value tagged with its index and side (0 = A/row, 1 = B/column),
    # sorted by value in descending order.
    array = []
    for i in range(N):
        array.append((A[i], i, 0))
        array.append((B[i], i, 1))
    array.sort(key=lambda x: x[0], reverse=True)

    # Block-pair prefix table and per-row / per-column block prefix tables.
    cum_rect_values = prepare_cum_rect_values(sqrt_n, array)
    row_values, col_values = prepare_row_col_values(N, sqrt_n, array)

    calculator = Calculator(A, B, sqrt_n, cum_rect_values, row_values, col_values)
    calculate = calculator.calculate
    answers = []
    for l, d, r, u in queries:
        answers.append(
            calculate(r, u)
            - calculate(r, d - 1)
            - calculate(l - 1, u)
            + calculate(l - 1, d - 1)
        )
    # One buffered write instead of a print() per query.
    if answers:
        sys.stdout.write("\n".join(map(str, answers)) + "\n")








# Run the solver only when this file is executed as a script.
if __name__ == "__main__":
    main()
0