結果

問題 No.1300 Sum of Inversions
コンテスト
ユーザー 回転
提出日時 2025-11-15 18:53:16
言語 PyPy3
(7.3.15)
結果
TLE  
実行時間 -
コード長 4,060 bytes
コンパイル時間 331 ms
コンパイル使用メモリ 82,848 KB
実行使用メモリ 211,712 KB
最終ジャッジ日時 2025-11-15 18:54:30
合計ジャッジ時間 65,501 ms
ジャッジサーバーID
(参考情報)
judge5 / judge2
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 3
other AC * 21 TLE * 13
権限があれば一括ダウンロードができます

ソースコード

diff #

class WaveletMatrix:
    def __init__(self,V):
        self.n=len(V)
        self.lg=max(V).bit_length()
        self.ranks=[]
        self.accs=[]
        self.original_V=V
        V=list(V)
        for bit in range(self.lg-1,-1,-1):
            rank=[0]*(self.n+1)
            for i,v in enumerate(V):
                rank[i+1]=rank[i]+((v>>bit)&1)
                
            swp=[0]*self.n
            zero,one=0,self.n-rank[self.n]
            for v in V:
                if (v>>bit)&1:
                    swp[one]=v
                    one+=1
                else:
                    swp[zero]=v
                    zero+=1
            
            acc=[0]*(self.n+1)
            for i,v in enumerate(swp):
                acc[i+1]=acc[i]+v
            
            V=swp
            self.ranks.append(rank)
            self.accs.append(acc)
        self.accs.append([0])
        for v in self.original_V:
            self.accs[-1].append(self.accs[-1][-1]+v)
    def access(self,i):
        return self.original_V[i]
    def rank(self,r,x):
        if (x>>self.lg)&1:return 0
        for i in range(self.lg-1,-1,-1):
            bit=(x>>i)&1
            rank=self.ranks[i]
            if bit:
                zeros=self.n-rank[r]
                r=zeros+rank[r]
            else:
                r-=rank[r]
        return r
    def rank_range(self,l,r,x):
        return self.rank(r,x)-self.rank(l,x)
    def quantile(self,l,r,k):
        res=0
        for i in range(self.lg-1,-1,-1):
            rank=self.ranks[self.lg-1-i]
            ones=rank[r]-rank[l]
            zeros=(r-l)-ones
            if k<zeros:
                l-=rank[l]
                r-=rank[r]
            else:
                res|=1<<i
                k-=zeros
                zero_sum=self.n-rank[self.n]
                l=zero_sum+rank[l]
                r=zero_sum+rank[r]
        return res
    def _range_freq(self,l,r,x):
        if x.bit_length() > self.lg:
            return r-l
        res=0
        for i in range(self.lg-1,-1,-1):
            bit=(x>>i)&1
            rank=self.ranks[self.lg-1-i]
            ones=rank[r]-rank[l]
            zeros=(r-l)-ones
            if bit:
                res+=zeros
                zero_sum=self.n-rank[self.n]
                l=zero_sum+rank[l]
                r=zero_sum+rank[r]
            else:
                l-=rank[l]
                r-=rank[r]
        return res
    def range_freq(self,left,right,lower,upper):
        return self._range_freq(left,right,upper)-self._range_freq(left,right,lower)
    def prev_value(self,left,right,upper):
        cnt=self._range_freq(left,right,upper)
        return self.quantile(left,right,cnt-1) if cnt>0 else None
    def next_value(self,left,right,lower):
        cnt=self._range_freq(left,right,lower)
        return self.quantile(left,right,cnt) if cnt<right-left else None
    def _range_sum(self,l,r,x):
        if self.lg<x.bit_length():return self.accs[-1][r]-self.accs[-1][l]
        res=0
        MOD = 998244353
        for i in range(self.lg-1,-1,-1):
            bit=(x>>i)&1
            rank=self.ranks[self.lg-1-i]
            acc=self.accs[self.lg-1-i]
            if bit:
                zero_sum=self.n-rank[self.n]
                l0=l-rank[l]
                r0=r-rank[r]
                res+=acc[r0]-acc[l0]
                l=zero_sum+rank[l]
                r=zero_sum+rank[r]
            else:
                l-=rank[l]
                r-=rank[r]
            res %= MOD
        return res
    def range_sum(self,left,right,lower,upper):
        return self._range_sum(left,right,upper)-self._range_sum(left,right,lower)

MOD = 998244353
N = int(input())
A = list(map(int,input().split()))

ans = 0
wm = WaveletMatrix(A)
for i in range(1,N-1):
    left = wm.range_freq(0,i,A[i]+1,10**18)
    left_sum = wm.range_sum(0,i,A[i]+1,10**18)
    right = wm.range_freq(i+1,N,0,A[i])
    right_sum = wm.range_sum(i+1,N,0,A[i])
    
    ans += left * right * A[i] % MOD
    ans += left_sum * right % MOD
    ans += left * right_sum % MOD
    ans %= MOD
print(ans)
0