結果

問題 No.1300 Sum of Inversions
ユーザー NoneNone
提出日時 2021-04-30 01:42:39
言語 PyPy3
(7.3.15)
結果
AC  
実行時間 1,291 ms / 2,000 ms
コード長 11,681 bytes
コンパイル時間 159 ms
コンパイル使用メモリ 82,956 KB
実行使用メモリ 263,264 KB
最終ジャッジ日時 2024-07-18 01:35:51
合計ジャッジ時間 32,748 ms
ジャッジサーバーID
(参考情報)
judge1 / judge2
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 44 ms
55,168 KB
testcase_01 AC 43 ms
54,784 KB
testcase_02 AC 45 ms
55,296 KB
testcase_03 AC 1,102 ms
210,208 KB
testcase_04 AC 1,017 ms
209,648 KB
testcase_05 AC 819 ms
161,652 KB
testcase_06 AC 1,208 ms
245,632 KB
testcase_07 AC 1,151 ms
243,624 KB
testcase_08 AC 1,251 ms
255,112 KB
testcase_09 AC 1,220 ms
255,180 KB
testcase_10 AC 727 ms
150,588 KB
testcase_11 AC 701 ms
151,420 KB
testcase_12 AC 989 ms
209,368 KB
testcase_13 AC 935 ms
208,460 KB
testcase_14 AC 1,291 ms
263,264 KB
testcase_15 AC 1,218 ms
255,200 KB
testcase_16 AC 1,066 ms
227,580 KB
testcase_17 AC 651 ms
143,124 KB
testcase_18 AC 863 ms
152,416 KB
testcase_19 AC 894 ms
173,004 KB
testcase_20 AC 875 ms
173,632 KB
testcase_21 AC 901 ms
174,140 KB
testcase_22 AC 811 ms
161,348 KB
testcase_23 AC 1,107 ms
245,348 KB
testcase_24 AC 835 ms
161,508 KB
testcase_25 AC 717 ms
151,836 KB
testcase_26 AC 704 ms
151,696 KB
testcase_27 AC 825 ms
160,928 KB
testcase_28 AC 1,251 ms
260,780 KB
testcase_29 AC 908 ms
173,656 KB
testcase_30 AC 1,250 ms
255,912 KB
testcase_31 AC 838 ms
161,392 KB
testcase_32 AC 877 ms
162,316 KB
testcase_33 AC 300 ms
104,064 KB
testcase_34 AC 328 ms
103,148 KB
testcase_35 AC 512 ms
251,560 KB
testcase_36 AC 662 ms
262,764 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

class Bit:
    """Fenwick tree (binary indexed tree) over n elements, 0-indexed.

    tree[i] stores the sum of the elements in [i & (i + 1), i], so both a
    prefix-sum query (``get``) and a point update (``add``) walk O(log n)
    slots.
    """

    def __init__(self, n, array=()):
        """
        :param n: number of elements
        :param array: optional initial values; array[i] is added at index i
            (entries at index >= n are ignored, because ``add`` skips them)
        """
        self.n = n
        self.tree = [0] * (n + 1)
        # Exponent of the largest power of two <= n; lower_bound starts its
        # binary descent from this step size.
        self.depth = n.bit_length() - 1
        self.build(array)

    def get(self, i):
        """ return summation of elements in [0,i) (i is clamped to [0, n]) """
        tree = self.tree
        s = 0
        # Clamp to n: tree[n] is never updated, so i > n would otherwise yield
        # a meaningless partial sum (or IndexError for i > n + 1).
        i = min(i, self.n) - 1
        while i >= 0:
            s += tree[i]
            i = (i & (i + 1)) - 1
        return s

    def build(self, array):
        """ build BIT from array: add array[i] to the i-th element for every i """
        for i, a in enumerate(array):
            self.add(i, a)

    def add(self, i, x):
        """ add x to i-th element (0 <= i; indices >= n are silently ignored)

        :raises IndexError: if i is negative -- ``i |= i + 1`` never leaves the
            negative numbers, so the update loop would otherwise never end.
        """
        if i < 0:
            raise IndexError("Bit.add: negative index {0}".format(i))
        tree = self.tree
        n = self.n
        while i < n:
            tree[i] += x
            i |= i + 1

    def get_range(self, i, j):
        """ return summation of elements in [i,j) """
        if i == 0:
            return self.get(j)
        return self.get(j) - self.get(i)

    def lower_bound(self, x, equal=False):
        """
        return tuple = (maximum i s.t. a0+a1+...+ai < x (if not existing, -1), a0+a1+...+ai)
        if one wants to include equal (i.e., a0+a1+...+ai <= x), please set equal = True
        (Caution) We must assume that A_i >= 0; the binary descent relies on
        prefix sums being non-decreasing.
        """
        tree = self.tree
        n = self.n
        sum_ = 0
        pos = -1  # for a 1-indexed tree this would start at 0
        for i in range(self.depth, -1, -1):
            k = pos + (1 << i)
            if k < n:  # for a 1-indexed tree this would be k <= n
                # tree[k] covers exactly the elements [pos + 1, k].
                t = sum_ + tree[k]
                if t < x or (equal and t == x):
                    sum_ = t
                    pos = k
        return pos, sum_

    def __getitem__(self, i):
        """ [a0, a1, a2, ...]; negative i counts from the end """
        if i < 0:
            i = self.n + i
        return self.get_range(i, i + 1)

    def __setitem__(self, i, x):
        """ set the i-th element to x; negative i counts from the end """
        if i < 0:
            # Normalise here too: passing a negative index on to add() would
            # (in the original) loop forever.
            i = self.n + i
        self.add(i, x - self[i])

    def __iter__(self):
        """ [a0, a1, a2, ...] """
        for i in range(self.n):
            yield self.get_range(i, i + 1)

    def __str__(self):
        text1 = " ".join(["element:            "] + list(map(str, self)))
        text2 = " ".join(["cumsum(1-indexed):  "]+list(str(self.get(i)) for i in range(1,self.n+1)))
        return "\n".join((text1, text2))

class SortedList:
    """Multiset of integers in [0, n], backed by two Fenwick trees.

    ``p`` counts how many copies of each value are stored and ``S`` holds the
    sum of the stored copies of each value, so rank, k-th smallest and
    range-sum queries are O(log n).
    """

    def __init__(self, n, A=()):
        """
        :param n: maximum value of A (stored values must lie in [0, n])
        :param A: initial elements
        self.size: number of elements in BitSet
        """
        self.n = n
        self.p = Bit(self.n + 1)
        self.size = 0
        self.flip = 0  # reported by flip_counter(); nothing in this class changes it
        self.S = Bit(self.n + 1)
        for a in A:
            self.add(a)

    def add(self, x):
        """ insert one copy of x """
        self.p.add(x, 1)
        self.S.add(x, x)
        self.size += 1

    def remove(self, x):
        """ remove one copy of x (presence is not checked) """
        self.p.add(x, -1)
        self.S.add(x, -x)
        self.size -= 1

    def bisect_left(self, x):
        """ return bisect_left(sorted(B),x) """
        if x <= self.n:
            return self.p.get(x)
        else:
            return self.size

    def bisect_right(self, x):
        """ return bisect_right(sorted(B),x) """
        x += 1
        if x <= self.n:
            return self.p.get(x)
        else:
            return self.size

    def flip_counter(self):
        return self.flip

    def count(self, x):
        """ return the multiplicity of x (0 for values outside [0, n]) """
        # Guard the range: self.p[x] would wrap a negative x around to the
        # other end of the tree, and x > n is outside the tree entirely.
        if 0 <= x <= self.n:
            return self.p[x]
        return 0

    def count_range(self, l, r):
        """ return number of elements in half-open range [l,r) """
        return self.bisect_left(r) - self.bisect_left(l)

    def get_range(self, l, r):
        """ return the sum of the stored elements whose value lies in [l, r) """
        # Clamp both ends to the tree size; S.get already returns 0 for l <= 0.
        top = self.n + 1
        return self.S.get(min(r, top)) - self.S.get(min(l, top))

    def minimum(self, k=1):
        """ return k-th minimum value (1-indexed) """
        if 1 <= k <= self.size:
            return self.p.lower_bound(k)[0] + 1
        else:
            sys.stderr.write("minimum: list index out of range (k={0})\n".format(k))

    def min(self):
        return self.minimum(1)

    def max(self):
        """ return the largest element (reports an error and returns None if empty) """
        if self.size:
            return self.minimum(self.size)
        sys.stderr.write("max: no elements exist in this BitSet\n")

    def upper_bound(self, x, equal=False):
        """ return maximum element lower than x (<= x when equal=True) """
        # bisect_* clamp x to the tree size, so out-of-range queries are safe.
        k = self.bisect_right(x) if equal else self.bisect_left(x)
        if k:
            return self.minimum(k)
        else:
            sys.stderr.write("upper_bound: no element smaller than {0} in this BitSet\n".format(x))

    def lower_bound(self, x, equal=False):
        """ return minimum element greater than x (>= x when equal=True) """
        k = (self.bisect_left(x) if equal else self.bisect_right(x)) + 1
        if k <= self.size:
            return self.minimum(k)
        else:
            sys.stderr.write("lower_bound: no element larger than {0} in this BitSet\n".format(x))

    def __getitem__(self, k):
        """
        return k-th minimum element (0-indexed)
        B[k] = sorted(A)[k]
        """
        if len(self)==0:
            sys.stderr.write("__getitem__: no elements exist in this BitSet\n")
        elif k >= len(self):
            sys.stderr.write("__getitem__: index (={0}) is larger than the maximum index (={1})\n".format(k,len(self)-1))
        elif k >= 0:
            return self.minimum(k+1)
        else:
            sys.stderr.write("__getitem__: index (={0}) is negative \n".format(k))

    def __len__(self):
        return self.size

    def __iter__(self):
        """ yield the stored elements in ascending order, with multiplicity """
        for value in range(self.n + 1):
            copies = self.p[value]  # evaluate the O(log n) lookup only once
            for _ in range(copies):
                yield value

    def __str__(self):
        """ return sorted list """
        text1 = " ".join(list(map(str, self)))
        return "[" + text1 + "]"

class SortedList2:
    """ if we need compress

    Multiset over a universe of values fixed up front (``data``). Values are
    coordinate-compressed to indices 0..n-1; ``p`` holds the count and ``S``
    the sum of each compressed value in Fenwick trees. Values must be hashable
    and mutually comparable.
    """

    def __init__(self, data, A=()):
        """
        :param data: every value that may ever be stored (duplicates allowed)
        :param A: initial elements; each one must occur in ``data``
        self.size: number of elements in BitSet
        """
        self.data = sorted(set(data))
        self.n = len(self.data)
        self.p = Bit(self.n + 1)
        self.S = Bit(self.n + 1)
        self.size = 0
        self.code = {b: i for i, b in enumerate(self.data)}  # value -> index
        self.decode = list(self.data)  # index -> value
        for a in A:
            self.add(a)

    def _rank_lt(self, x):
        """ return the number of distinct universe values strictly less than x """
        i = self.code.get(x)
        return bisect_left(self.data, x) if i is None else i

    def _rank_le(self, x):
        """ return the number of distinct universe values less than or equal to x """
        i = self.code.get(x)
        return bisect_right(self.data, x) if i is None else i + 1

    def add(self, x):
        """ insert one copy of x (x must occur in ``data``) """
        i = self.code[x]
        self.p.add(i, 1)
        self.S.add(i, x)
        self.size += 1

    def remove(self, x):
        """ remove one copy of x (presence is not checked) """
        i = self.code[x]
        self.p.add(i, -1)
        self.S.add(i, -x)
        self.size -= 1

    def bisect_left(self, x):
        """ return bisect_left(sorted(B),x) """
        return self.p.get(self._rank_lt(x))

    def bisect_right(self, x):
        """ return bisect_right(sorted(B),x); valid for non-integer values too """
        return self.p.get(self._rank_le(x))

    def count(self, x):
        """ return the multiplicity of x (0 if x is not in the universe) """
        i = self.code.get(x)
        return 0 if i is None else self.p[i]

    def count_range(self, l, r):
        """ return number of elements in half-open range [l,r) """
        return self.bisect_left(r) - self.bisect_left(l)

    def get_range(self, l, r):
        """ return the sum of the stored elements whose value lies in [l, r) """
        return self.S.get(self._rank_lt(r)) - self.S.get(self._rank_lt(l))

    def minimum(self, k=1):
        """ return k-th minimum value (1-indexed) """
        if 1 <= k <= self.size:
            return self.decode[self.p.lower_bound(k)[0] + 1]
        else:
            sys.stderr.write("minimum: list index out of range (k={0})\n".format(k))

    def min(self):
        return self.minimum(1)

    def max(self):
        """ return the largest element (reports an error and returns None if empty) """
        if self.size:
            return self.minimum(self.size)
        sys.stderr.write("max: no elements exist in this BitSet\n")

    def upper_bound(self, x, equal=False):
        """ return maximum element lower than x (<= x when equal=True) """
        k = self.p.get(self._rank_le(x) if equal else self._rank_lt(x))
        if k:
            return self.minimum(k)
        else:
            sys.stderr.write("upper_bound: no element smaller than {0} in this BitSet\n".format(x))

    def lower_bound(self, x, equal=False):
        """ return minimum element greater than x (>= x when equal=True) """
        k = self.p.get(self._rank_lt(x) if equal else self._rank_le(x)) + 1
        if k <= self.size:
            return self.minimum(k)
        else:
            sys.stderr.write("lower_bound: no element larger than {0} in this BitSet\n".format(x))

    def nearest(self, x, k):
        """
        return k-th nearest value to x (k is 1-indexed)
        Values at equal distance are counted smaller-first. Distances are
        searched as integers in [0, 10**18], so values are assumed integral.
        """
        if k>len(self):
            sys.stderr.write("nearest: k (= {0}) is larger than the size of this BitSet\n".format(k))
            return

        def test(d):
            # True iff at most k elements lie in [x - d, x + d].
            r = self.bisect_right(x + d) - 1
            l = self.bisect_left(x - d)
            return r - l + 1 <= k

        # Largest d with at most k elements within distance d (d = 0 assumed).
        ok, ng = 0, 10**18 + 1
        while abs(ok - ng) > 1:
            mid = (ok + ng) // 2
            if test(mid):
                ok = mid
            else:
                ng = mid
        d = ok

        # self[l..r] are exactly the elements within distance d of x.
        r = self.bisect_right(x + d) - 1
        l = self.bisect_left(x - d)
        inside = r - l + 1
        if d == 0 and inside >= k:
            # At least k copies of x itself (test(0) is never evaluated, so
            # inside may exceed k here): the answer is x.
            return self[l]
        elif inside == k:
            # The k-th nearest is the farther end of self[l..r].
            R = self[r]
            L = self[l]
            if abs(x - L) <= abs(R - x):
                return R
            else:
                return L
        else:
            # Fewer than k within distance d: the answer is the nearest
            # element just outside [x - d, x + d].
            if l <= 0:
                return self[r + 1]
            elif r >= len(self) - 1:
                return self[l - 1]
            else:
                R = self[r + 1]
                L = self[l - 1]
                if abs(x - L) == abs(R - x):
                    # Tie: the smaller side is consumed first.
                    if self.count(L) >= k - inside:
                        return L
                    else:
                        return R
                elif abs(x - L) < abs(R - x):
                    return L
                else:
                    return R

    def __getitem__(self, k):
        """
        return k-th minimum element (0-indexed)
        B[k] = sorted(A)[k]
        """
        if len(self)==0:
            sys.stderr.write("__getitem__: no elements exist in this BitSet\n")
        elif k >= len(self):
            sys.stderr.write("__getitem__: index (={0}) is larger than the maximum index (={1})\n".format(k,len(self)-1))
        elif k >= 0:
            return self.minimum(k+1)
        else:
            sys.stderr.write("__getitem__: index (={0}) is negative \n".format(k))

    def __len__(self):
        return self.size

    def __iter__(self):
        """ yield the stored elements in ascending order, with multiplicity """
        for i, value in enumerate(self.decode):
            copies = self.p[i]  # evaluate the O(log n) lookup only once
            for _ in range(copies):
                yield value

    def __str__(self):
        """ return sorted list """
        text1 = " ".join(list(map(str, self)))
        return "[" + text1 + "]"

###################################################
def example():
    """Point the module-level ``input`` at a hard-coded sample instead of stdin.

    Meant for manual local testing (it is not called anywhere in the visible
    code): after the call, each ``input()`` returns the next sample line.
    """
    global input
    sample_text = """
6
10 5 1 5 10 7
        """
    sample_lines = sample_text.strip().split("\n")
    line_iter = iter(sample_lines)

    def read_line():
        return next(line_iter)

    input = read_line
###################################################

import sys
input = sys.stdin.readline
from bisect import bisect_left, bisect_right

MOD=998244353

# Sum over all i < j < k with A_i > A_j > A_k of (A_i + A_j + A_k), mod MOD.
# Each element is treated as the middle A_j: with `cnt_big` larger values to
# its left (summing to `sum_big`) and `cnt_small` smaller values to its right
# (summing to `sum_small`), its triples contribute
#     mid * cnt_big * cnt_small + sum_big * cnt_small + sum_small * cnt_big.
N = int(input())
A = list(map(int, input().split()))
seen = SortedList2(A)        # values strictly left of the current position
unseen = SortedList2(A, A)   # values at or right of the current position
top = max(A)
res = 0
for mid in A:
    unseen.remove(mid)
    cnt_big = len(seen) - seen.bisect_right(mid)
    sum_big = seen.get_range(mid + 1, top + 1)
    cnt_small = unseen.bisect_left(mid)
    # Sums right-hand values in [0, mid); presumably the problem guarantees
    # A_i >= 1 so nothing below 0 is missed -- confirm against the statement.
    sum_small = unseen.get_range(0, mid)
    res = (res + mid * cnt_big * cnt_small
           + sum_big * cnt_small
           + sum_small * cnt_big) % MOD
    seen.add(mid)

print(res)
0