結果

問題 No.1649 Manhattan Square
ユーザー LyricalMaestroLyricalMaestro
提出日時 2024-11-12 00:34:31
言語 PyPy3
(7.3.15)
結果
AC  
実行時間 1,284 ms / 3,000 ms
コード長 3,936 bytes
コンパイル時間 1,216 ms
コンパイル使用メモリ 82,124 KB
実行使用メモリ 207,028 KB
最終ジャッジ日時 2024-11-12 00:35:26
合計ジャッジ時間 47,465 ms
ジャッジサーバーID
(参考情報)
judge5 / judge4
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 2
other AC * 43
権限があれば一括ダウンロードができます

ソースコード

diff #

## https://yukicoder.me/problems/no/1649


MOD = 998244353

class BinaryIndexTree:
    """
    フェニック木(BinaryIndexTree)の基本的な機能を実装したクラス
    """
    def __init__(self, size):
        self.size = size
        self.array = [0] * (size + 1)
    
    def add(self, x, a):
        index = x
        while index <= self.size:
            self.array[index] += a
            self.array[index] %= MOD
            index += index & (-index)
    
    def sum(self, x):
        index = x
        ans = 0
        while index > 0:
            ans += self.array[index]
            ans %= MOD
            index -= index & (-index)
        return ans

    def least_upper_bound(self, value):
        if self.sum(self.size) < value:
            return -1
        elif value <= 0:
            return 0

        m = 1
        while m < self.size:
            m *= 2

        k = 0
        k_sum = 0
        while m > 0:
            k0 = k + m
            if k0 < self.size:
                if k_sum + self.array[k0] < value:
                    k_sum += self.array[k0]
                    k += m
            m //= 2
        if k < self.size:
            return k + 1
        else:
            return -1


def pow2_sum(N, value_list):
    pow2sum = 0
    for v in value_list:
        pow2sum += (v * v) % MOD
        pow2sum %= MOD

    pow1sum = 0
    for v in value_list:
        pow1sum += v
        pow1sum %= MOD
    answer = (N * pow2sum) % MOD
    answer -= (pow1sum * pow1sum) % MOD
    answer %= MOD
    return answer



def main():
    N = int(input())
    xy = []
    for _ in range(N):
        x, y = map(int, input().split())
        xy.append((x, y))
    
    # |xi - xj|^2 の和について
    x_list = [x for x, _ in xy]
    ans_x = pow2_sum(N, x_list)

    y_list = [y for _, y in xy]
    ans_y = pow2_sum(N, y_list)

    # |xi - xj||yi - yj|の和について
    x_map = {}
    for x, y in xy:
        if x not in x_map:
            x_map[x] = []
        x_map[x].append(y)
    
    x_array = [(x, y_array) for x, y_array in x_map.items()]
    x_array.sort(key=lambda x : x[0])

    # yについての座標圧縮
    y_set = set(y_list)
    y_list = list(y_set)
    y_list.sort()
    y_map = {}
    for i, y in enumerate(y_list):
        y_map[y] = i + 1

    bit_count = BinaryIndexTree(len(y_list))
    bit_x_sum = BinaryIndexTree(len(y_list))
    bit_y_sum = BinaryIndexTree(len(y_list))
    bit_xy_sum = BinaryIndexTree(len(y_list))
    ans_xy = 0
    for x, y_array in x_array:
        for y in y_array:
            # yj <= yiとなるもの
            n = bit_count.sum(y_map[y])
            sy = bit_y_sum.sum(y_map[y])
            sx = bit_x_sum.sum(y_map[y])
            sxy = bit_xy_sum.sum(y_map[y])

            ans1 = (((n * x) % MOD) * y) % MOD
            ans1 -= (sy * x) % MOD
            ans1 %= MOD
            ans1 -= (y * sx) % MOD
            ans1 %= MOD
            ans1 += sxy
            ans1 %= MOD

            # yj > yiとなるもの
            n = (bit_count.sum(bit_count.size) - n) % MOD
            sy = (bit_y_sum.sum(bit_y_sum.size) - sy) % MOD
            sx = (bit_x_sum.sum(bit_x_sum.size) - sx) % MOD
            sxy = (bit_xy_sum.sum(bit_xy_sum.size) - sxy) % MOD
            
            ans2 = (((n * x) % MOD) * y) % MOD
            ans2 *= -1
            ans2 %= MOD
            ans2 += (sy * x) % MOD
            ans2 %= MOD
            ans2 += (y * sx) % MOD
            ans2 -= sxy
            ans2 %= MOD

            ans_xy += (ans1 + ans2) % MOD
            ans_xy %= MOD

        for y in y_array:
            bit_count.add(y_map[y], 1)
            bit_y_sum.add(y_map[y], y)
            bit_x_sum.add(y_map[y], x)
            bit_xy_sum.add(y_map[y], (x * y) % MOD)
    
    answer = (2 * ans_xy) % MOD
    answer += ans_x
    answer %= MOD
    answer += ans_y
    answer %= MOD

    print(answer)








if __name__ == "__main__":
    main()
0