結果

問題 No.1300 Sum of Inversions
コンテスト
ユーザー 回転
提出日時 2026-05-22 18:11:37
言語 PyPy3
(7.3.17)
コンパイル:
pypy3 -mpy_compile _filename_
実行:
pypy3 _filename_
結果
TLE  
実行時間 -
コード長 6,881 bytes
記録
記録タグの例:
初AC ショートコード 純ショートコード 純主流ショートコード 最速実行時間
コンパイル時間 145 ms
コンパイル使用メモリ 85,120 KB
実行使用メモリ 252,932 KB
最終ジャッジ日時 2026-05-22 18:11:55
合計ジャッジ時間 4,357 ms
ジャッジサーバーID
(参考情報)
judge2_1 / judge3_0
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 3
other TLE * 1 -- * 33
権限があれば一括ダウンロードができます

ソースコード

diff #
raw source code

class WaveletMatrix:
    """Static wavelet matrix over a sequence of non-negative integers.

    The structure has ``lg`` levels, one per bit from the most significant
    to the least significant.  At every level the current sequence is stably
    partitioned into "bit is 0" followed by "bit is 1".  For each level we
    keep:

    * ``ones[i][p]``  -- number of 1 bits among the first ``p`` positions of
      level ``i`` (a prefix-count table, so rank is a single list lookup);
    * ``zeros[i]``    -- total number of 0 bits at level ``i``;
    * ``accs[i][p]``  -- prefix sums of the sequence *after* the level-``i``
      partition (used by the sum queries);
    * ``B[i]``        -- the level's bit vector packed into one integer
      (bit ``j`` set iff position ``j`` has a 1).  Queries no longer use it;
      it is kept so existing readers of the attribute keep working.

    All ranges are half-open ``[l, r)`` with ``0 <= l <= r <= n``.  Every
    query touches each level once with O(1) work per level.
    """

    def __init__(self, V):
        """Build the matrix from ``V`` (a sequence of non-negative ints)."""
        n = len(V)
        self.n = n
        # Always keep at least one level so empty / all-zero input works.
        lg = max(max(V).bit_length() if V else 0, 1)
        self.lg = lg

        self.B = [0] * lg
        self.zeros = [0] * lg
        self.ones = []
        self.accs = []

        self.original_V = V
        curr_V = list(V)

        for i in range(lg):
            mask = 1 << (lg - 1 - i)
            packed = bytearray((n + 7) >> 3)
            ones_prefix = [0] * (n + 1)
            zero_arr = []
            one_arr = []
            z_app = zero_arr.append
            o_app = one_arr.append

            seen = 0
            for j, v in enumerate(curr_V):
                if v & mask:
                    packed[j >> 3] |= 1 << (j & 7)
                    seen += 1
                    o_app(v)
                else:
                    z_app(v)
                ones_prefix[j + 1] = seen

            # Linear-time construction of the packed bit vector; repeatedly
            # OR-ing single bits into a growing int would be quadratic.
            self.B[i] = int.from_bytes(packed, "little")
            self.ones.append(ones_prefix)
            self.zeros[i] = len(zero_arr)
            curr_V = zero_arr + one_arr
            self.accs.append(self._prefix_sums(curr_V))

        self.accs_orig = self._prefix_sums(V)

    @staticmethod
    def _prefix_sums(values):
        """Return ``acc`` with ``acc[p] == sum(values[:p])``."""
        acc = [0] * (len(values) + 1)
        total = 0
        for j, v in enumerate(values):
            total += v
            acc[j + 1] = total
        return acc

    def access(self, i):
        """Return the element at index ``i`` of the original sequence."""
        return self.original_V[i]

    def rank(self, r, x):
        """Return the number of occurrences of ``x`` in ``[0, r)``."""
        lg = self.lg
        # Values that do not fit in ``lg`` bits (or are negative) never occur.
        if x < 0 or x >> lg:
            return 0
        ones = self.ones
        zeros = self.zeros
        # Both ends must be tracked: the block of ``x`` in the last level
        # does not start at position 0 in general.
        l = 0
        for i in range(lg):
            row = ones[i]
            if (x >> (lg - 1 - i)) & 1:
                z = zeros[i]
                l = z + row[l]
                r = z + row[r]
            else:
                l -= row[l]
                r -= row[r]
        return r - l

    def rank_range(self, l, r, x):
        """Return the number of occurrences of ``x`` in ``[l, r)``."""
        return self.rank(r, x) - self.rank(l, x)

    def quantile(self, l, r, k):
        """Return the ``k``-th smallest (0-indexed) value in ``[l, r)``.

        The caller must ensure ``0 <= k < r - l``; no bounds check is done.
        """
        res = 0
        ones = self.ones
        zeros = self.zeros
        lg = self.lg
        for i in range(lg):
            row = ones[i]
            l1 = row[l]
            r1 = row[r]
            z = (r - l) - (r1 - l1)  # zeros inside the current range
            if k < z:
                l -= l1
                r -= r1
            else:
                res |= 1 << (lg - 1 - i)
                k -= z
                z_cnt = zeros[i]
                l = z_cnt + l1
                r = z_cnt + r1
        return res

    def _range_freq(self, l, r, x):
        """Return the number of elements in ``[l, r)`` strictly less than ``x``."""
        if x <= 0:
            # Nothing is smaller than a non-positive bound; also avoids
            # reading the sign-extended bits of a negative ``x``.
            return 0
        lg = self.lg
        if x.bit_length() > lg:
            return r - l
        res = 0
        ones = self.ones
        zeros = self.zeros
        for i in range(lg):
            if l == r:
                break
            row = ones[i]
            l1 = row[l]
            r1 = row[r]
            if (x >> (lg - 1 - i)) & 1:
                # Everything on the 0 side is smaller than x at this bit.
                res += (r - r1) - (l - l1)
                z = zeros[i]
                l = z + l1
                r = z + r1
            else:
                l -= l1
                r -= r1
        return res

    def range_freq(self, left, right, lower, upper):
        """Count elements ``v`` in ``[left, right)`` with ``lower <= v < upper``."""
        return self._range_freq(left, right, upper) - self._range_freq(left, right, lower)

    def prev_value(self, left, right, upper):
        """Return the largest value in ``[left, right)`` below ``upper``, or None."""
        cnt = self._range_freq(left, right, upper)
        return self.quantile(left, right, cnt - 1) if cnt > 0 else None

    def next_value(self, left, right, lower):
        """Return the smallest value in ``[left, right)`` that is >= ``lower``, or None."""
        cnt = self._range_freq(left, right, lower)
        return self.quantile(left, right, cnt) if cnt < right - left else None

    def _range_sum(self, l, r, x):
        """Return the sum of the elements in ``[l, r)`` strictly less than ``x``."""
        if x <= 0:
            return 0
        lg = self.lg
        if lg < x.bit_length():
            return self.accs_orig[r] - self.accs_orig[l]
        res = 0
        ones = self.ones
        zeros = self.zeros
        accs = self.accs
        for i in range(lg):
            if l == r:
                break
            row = ones[i]
            l1 = row[l]
            r1 = row[r]
            l0 = l - l1
            r0 = r - r1
            if (x >> (lg - 1 - i)) & 1:
                # The 0 side occupies [l0, r0) in the partitioned sequence.
                acc = accs[i]
                res += acc[r0] - acc[l0]
                z = zeros[i]
                l = z + l1
                r = z + r1
            else:
                l = l0
                r = r0
        return res

    def range_sum(self, left, right, lower, upper):
        """Sum the elements ``v`` in ``[left, right)`` with ``lower <= v < upper``."""
        return self._range_sum(left, right, upper) - self._range_sum(left, right, lower)

    def _build_distinct_wm(self):
        """Build the helper matrix used by :meth:`range_distinct`.

        ``P[i]`` is (index of the previous occurrence of ``V[i]``) + 1, or 0
        when ``V[i]`` has not appeared before.
        """
        P = [0] * self.n
        last_pos = {}
        for i, v in enumerate(self.original_V):
            P[i] = last_pos.get(v, -1) + 1
            last_pos[v] = i
        self._distinct_wm = WaveletMatrix(P)

    def range_distinct(self, left, right):
        """Return the number of distinct values in ``[left, right)``.

        An element is the first occurrence of its value inside the range
        exactly when its previous occurrence lies before ``left``.
        """
        if not hasattr(self, "_distinct_wm"):
            self._build_distinct_wm()
        return self._distinct_wm.range_freq(left, right, 0, left + 1)

    def bottom_k_sum(self, l, r, k):
        """Return the sum of the ``k`` smallest elements in ``[l, r)``.

        ``k <= 0`` yields 0 and ``k >= r - l`` yields the sum of the range.
        """
        if k <= 0:
            return 0
        if k >= r - l:
            return self.accs_orig[r] - self.accs_orig[l]
        res = 0
        val = 0
        ones = self.ones
        zeros = self.zeros
        accs = self.accs
        lg = self.lg
        for i in range(lg):
            row = ones[i]
            l1 = row[l]
            r1 = row[r]
            z = (r - l) - (r1 - l1)
            l0 = l - l1
            r0 = r - r1
            if k <= z:
                l = l0
                r = r0
            else:
                # Take the whole 0 side, continue on the 1 side.
                acc = accs[i]
                res += acc[r0] - acc[l0]
                k -= z
                val |= 1 << (lg - 1 - i)
                z_cnt = zeros[i]
                l = z_cnt + l1
                r = z_cnt + r1
        # The remaining range holds copies of ``val``; take ``k`` of them.
        res += k * val
        return res

    def top_k_sum(self, l, r, k):
        """Return the sum of the ``k`` largest elements in ``[l, r)``.

        ``k <= 0`` yields 0 and ``k >= r - l`` yields the sum of the range.
        """
        if k <= 0:
            return 0
        length = r - l
        total_sum = self.accs_orig[r] - self.accs_orig[l]
        if k >= length:
            return total_sum
        return total_sum - self.bottom_k_sum(l, r, length - k)

# Driver: reads N and the sequence A from stdin, then for every middle index i
# accumulates, over pairs (left element > A[i], right element < A[i]), the sum
# of the three values, and prints the total modulo MOD.
MOD = 998244353
N = int(input())
A = list(map(int, input().split()))
MAX = max(A)

WM = WaveletMatrix(A)
ans = 0
for i in range(1, N - 1):
    mid = A[i]
    # Each count is needed in several terms; query it once.
    left_cnt = WM.range_freq(0, i, mid + 1, MAX + 1)   # larger elements before i
    if left_cnt == 0:
        continue  # no valid left partner: every term below would be zero
    right_cnt = WM.range_freq(i + 1, N, 0, mid)        # smaller elements after i
    if right_cnt == 0:
        continue  # no valid right partner: every term below would be zero
    left_sum = WM.range_sum(0, i, mid + 1, MAX + 1)
    right_sum = WM.range_sum(i + 1, N, 0, mid)
    ans += left_sum * right_cnt
    ans += mid * left_cnt * right_cnt
    ans += left_cnt * right_sum
    ans %= MOD
print(ans)
0