結果
問題 | No.1300 Sum of Inversions |
ユーザー | None |
提出日時 | 2021-04-30 01:42:39 |
言語 | PyPy3 (7.3.15) |
結果 |
AC
|
実行時間 | 1,291 ms / 2,000 ms |
コード長 | 11,681 bytes |
コンパイル時間 | 159 ms |
コンパイル使用メモリ | 82,956 KB |
実行使用メモリ | 263,264 KB |
最終ジャッジ日時 | 2024-07-18 01:35:51 |
合計ジャッジ時間 | 32,748 ms |
ジャッジサーバーID (参考情報) |
judge1 / judge2 |
(要ログイン)
テストケース
テストケース表示入力 | 結果 | 実行時間 実行使用メモリ |
---|---|---|
testcase_00 | AC | 44 ms
55,168 KB |
testcase_01 | AC | 43 ms
54,784 KB |
testcase_02 | AC | 45 ms
55,296 KB |
testcase_03 | AC | 1,102 ms
210,208 KB |
testcase_04 | AC | 1,017 ms
209,648 KB |
testcase_05 | AC | 819 ms
161,652 KB |
testcase_06 | AC | 1,208 ms
245,632 KB |
testcase_07 | AC | 1,151 ms
243,624 KB |
testcase_08 | AC | 1,251 ms
255,112 KB |
testcase_09 | AC | 1,220 ms
255,180 KB |
testcase_10 | AC | 727 ms
150,588 KB |
testcase_11 | AC | 701 ms
151,420 KB |
testcase_12 | AC | 989 ms
209,368 KB |
testcase_13 | AC | 935 ms
208,460 KB |
testcase_14 | AC | 1,291 ms
263,264 KB |
testcase_15 | AC | 1,218 ms
255,200 KB |
testcase_16 | AC | 1,066 ms
227,580 KB |
testcase_17 | AC | 651 ms
143,124 KB |
testcase_18 | AC | 863 ms
152,416 KB |
testcase_19 | AC | 894 ms
173,004 KB |
testcase_20 | AC | 875 ms
173,632 KB |
testcase_21 | AC | 901 ms
174,140 KB |
testcase_22 | AC | 811 ms
161,348 KB |
testcase_23 | AC | 1,107 ms
245,348 KB |
testcase_24 | AC | 835 ms
161,508 KB |
testcase_25 | AC | 717 ms
151,836 KB |
testcase_26 | AC | 704 ms
151,696 KB |
testcase_27 | AC | 825 ms
160,928 KB |
testcase_28 | AC | 1,251 ms
260,780 KB |
testcase_29 | AC | 908 ms
173,656 KB |
testcase_30 | AC | 1,250 ms
255,912 KB |
testcase_31 | AC | 838 ms
161,392 KB |
testcase_32 | AC | 877 ms
162,316 KB |
testcase_33 | AC | 300 ms
104,064 KB |
testcase_34 | AC | 328 ms
103,148 KB |
testcase_35 | AC | 512 ms
251,560 KB |
testcase_36 | AC | 662 ms
262,764 KB |
ソースコード
"""yukicoder No.1300 "Sum of Inversions".

For every triple i < j < k with A[i] > A[j] > A[k], add A[i] + A[j] + A[k];
print the total modulo 998244353.  The solution sweeps the middle index j
while two Fenwick-tree-backed multisets track the elements to its left and
to its right (counts and value sums).
"""
import sys
from bisect import bisect_left, bisect_right

input = sys.stdin.readline

MOD = 998244353


class Bit:
    """0-indexed Fenwick tree (binary indexed tree) with point add / prefix sum."""

    def __init__(self, n, array=None):
        """
        :param n: number of elements
        :param array: optional initial values; array[i] is added to slot i
        """
        self.n = n
        # One spare slot so that get(n + 1) may safely read tree[n]
        # (add() never writes it, so it always stays 0).
        self.tree = [0] * (n + 1)
        # Highest power-of-two step used by the binary descent in lower_bound.
        self.depth = n.bit_length() - 1
        if array is not None:
            self.build(array)

    def get(self, i):
        """ return summation of elements in [0,i) """
        tree = self.tree
        s = 0
        i -= 1
        while i >= 0:
            s += tree[i]
            # Jump to the last index before the block covered by tree[i].
            i = (i & (i + 1)) - 1
        return s

    def build(self, array):
        """ build BIT from array (adds array[i] to the i-th element) """
        for i, a in enumerate(array):
            self.add(i, a)

    def add(self, i, x):
        """ add x to i-th element (i must be a non-negative index) """
        tree = self.tree
        n = self.n
        while i < n:
            tree[i] += x
            i |= i + 1

    def get_range(self, i, j):
        """ return summation of elements in [i,j) """
        if i == 0:
            return self.get(j)
        return self.get(j) - self.get(i)

    def lower_bound(self, x, equal=False):
        """
        return tuple = (maximum i s.t. a0+a1+...+ai < x (if not existing, -1),
                        a0+a1+...+ai)
        if one wants to include equality (i.e., a0+a1+...+ai <= x),
        please set equal = True
        (Caution) We must assume that A_i >= 0
        """
        tree = self.tree
        n = self.n
        sum_ = 0
        pos = -1  # for a 1-indexed tree this would start at pos = 0
        if equal:
            for i in range(self.depth, -1, -1):
                k = pos + (1 << i)
                # for a 1-indexed tree the bound would be k <= self.n
                if k < n and sum_ + tree[k] <= x:
                    sum_ += tree[k]
                    pos += 1 << i
        else:
            for i in range(self.depth, -1, -1):
                k = pos + (1 << i)
                # for a 1-indexed tree the bound would be k <= self.n
                if k < n and sum_ + tree[k] < x:
                    sum_ += tree[k]
                    pos += 1 << i
        return pos, sum_

    def __getitem__(self, i):
        """ i-th element of [a0, a1, a2, ...]; negative i counts from the end """
        if i < 0:
            i = self.n + i
        return self.get_range(i, i + 1)

    def __setitem__(self, i, x):
        """ set the i-th element to x; negative i counts from the end """
        # Normalise first: add() with a negative index would never terminate
        # because (-1 | 0) == -1 keeps the loop condition true forever.
        if i < 0:
            i = self.n + i
        self.add(i, x - self[i])

    def __iter__(self):
        """ iterate over [a0, a1, a2, ...] """
        for i in range(self.n):
            yield self.get_range(i, i + 1)

    def __str__(self):
        text1 = " ".join(["element: "] + list(map(str, self)))
        text2 = " ".join(["cumsum(1-indexed): "]
                         + [str(self.get(i)) for i in range(1, self.n + 1)])
        return "\n".join((text1, text2))


class SortedList:
    """Multiset of integers in [0, n] backed by two Fenwick trees.

    ``p`` stores the multiplicity of each value and ``S`` stores the sum of
    the stored copies of each value, so both rank queries and range sums are
    answered through Bit prefix sums.
    """

    def __init__(self, n, A=None):
        """
        :param n: maximum value of A
        self.size: number of elements in BitSet
        """
        self.n = n
        self.p = Bit(self.n + 1)
        self.size = 0
        self.flip = 0
        self.S = Bit(self.n + 1)
        if A is not None:
            for a in A:
                self.add(a)

    def add(self, x):
        """ insert one copy of x """
        self.p.add(x, 1)
        self.S.add(x, x)
        self.size += 1

    def remove(self, x):
        """ delete one copy of x (no check that x is actually present) """
        self.p.add(x, -1)
        self.S.add(x, -x)
        self.size -= 1

    def bisect_left(self, x):
        """ return bisect_left(sorted(B),x) """
        if x <= self.n:
            return self.p.get(x)
        return self.size

    def bisect_right(self, x):
        """ return bisect_right(sorted(B),x) """
        x += 1
        if x <= self.n:
            return self.p.get(x)
        return self.size

    def flip_counter(self):
        return self.flip

    def count(self, x):
        """ return multiplicity of x """
        return self.p[x]

    def count_range(self, l, r):
        """ return number of elements in the half-open interval [l,r) """
        return self.bisect_left(r) - self.bisect_left(l)

    def get_range(self, l, r):
        """ return sum of the elements whose value lies in [l,r) """
        # Clamp r so that the prefix never reaches past the last slot.
        return self.S.get(min(r, self.n + 1)) - self.S.get(l)

    def minimum(self, k=1):
        """ return k-th minimum value (1-indexed); None + stderr note if k > size """
        if k <= self.size:
            return self.p.lower_bound(k)[0] + 1
        sys.stderr.write("minimum: list index out of range (k={0})\n".format(k))

    def min(self):
        return self.minimum(1)

    def max(self):
        return self.p.lower_bound(self.size)[0] + 1

    def upper_bound(self, x, equal=False):
        """ return maximum element lower than x (or equal to x if equal=True) """
        k = self.p.get(x + equal)
        if k:
            return self.minimum(k)
        sys.stderr.write("upper_bound: no element smaller than {0} in this BitSet\n".format(x))

    def lower_bound(self, x, equal=False):
        """ return minimum element greater than x (or equal to x if equal=True) """
        k = self.p.get(x + 1 - equal) + 1
        if k <= self.size:
            return self.minimum(k)
        sys.stderr.write("lower_bound: no element larger than {0} in this BitSet\n".format(x))

    def __getitem__(self, k):
        """
        return k-th minimum element (0-indexed)
        B[k] = sorted(A)[k]
        """
        if len(self) == 0:
            sys.stderr.write("__getitem__: no elements exist in this BitSet\n")
        elif k >= len(self):
            sys.stderr.write("__getitem__: index (={0}) is larger than the maximum index (={1})\n".format(k, len(self) - 1))
        elif k >= 0:
            return self.minimum(k + 1)
        else:
            sys.stderr.write("__getitem__: index (={0}) is negative \n".format(k))

    def __len__(self):
        return self.size

    def __iter__(self):
        """ return sorted list """
        for i in range(self.n + 1):
            # Query the multiplicity once per value (each query is a Bit lookup).
            for _ in range(self.p[i]):
                yield i

    def __str__(self):
        """ return sorted list """
        text1 = " ".join(list(map(str, self)))
        return "[" + text1 + "]"


class SortedList2:
    """ SortedList with coordinate compression: values must come from ``data`` """

    def __init__(self, data, A=None):
        """
        :param data: every value that may ever be stored (duplicates allowed)
        :param A: optional initial elements, each of which must occur in data
        self.size: number of elements in BitSet
        """
        self.data = sorted(set(data))
        self.n = len(self.data)
        self.p = Bit(self.n + 1)   # multiplicity per compressed index
        self.S = Bit(self.n + 1)   # sum of stored values per compressed index
        self.size = 0
        self.code = {b: i for i, b in enumerate(self.data)}  # value -> index
        self.decode = list(self.data)                        # index -> value
        if A is not None:
            for a in A:
                self.add(a)

    def _rank(self, x):
        """ return the number of distinct values of data strictly below x """
        idx = self.code.get(x)
        if idx is None:
            idx = bisect_left(self.data, x)
        return idx

    def add(self, x):
        """ insert one copy of x (KeyError if x is not in data) """
        i = self.code[x]
        self.p.add(i, 1)
        self.S.add(i, x)
        self.size += 1

    def remove(self, x):
        """ delete one copy of x (KeyError if x is not in data) """
        i = self.code[x]
        self.p.add(i, -1)
        self.S.add(i, -x)
        self.size -= 1

    def bisect_left(self, x):
        """ return bisect_left(sorted(B),x) """
        return self.p.get(self._rank(x))

    def bisect_right(self, x):
        """ return bisect_right(sorted(B),x) """
        # Work in rank space instead of probing x + 1, so the answer is also
        # right for non-integer values.
        idx = self.code.get(x)
        if idx is None:
            idx = bisect_right(self.data, x)
        else:
            idx += 1
        return self.p.get(idx)

    def count(self, x):
        """ return multiplicity of x (KeyError if x is not in data) """
        return self.p[self.code[x]]

    def count_range(self, l, r):
        """ return number of elements in the half-open interval [l,r) """
        return self.bisect_left(r) - self.bisect_left(l)

    def get_range(self, l, r):
        """ return sum of the elements whose value lies in [l,r) """
        return self.S.get(self._rank(r)) - self.S.get(self._rank(l))

    def minimum(self, k=1):
        """ return k-th minimum value (1-indexed); None + stderr note if k > size """
        if k <= self.size:
            return self.decode[self.p.lower_bound(k)[0] + 1]
        sys.stderr.write("minimum: list index out of range (k={0})\n".format(k))

    def min(self):
        return self.minimum(1)

    def max(self):
        return self.decode[self.p.lower_bound(self.size)[0] + 1]

    def upper_bound(self, x, equal=False):
        """ return maximum element lower than x (or equal to x if equal=True) """
        y = self.code.get(x)
        if y is None:
            y = bisect_right(self.data, x)
        else:
            y += equal
        k = self.p.get(y)
        if k:
            return self.minimum(k)
        sys.stderr.write("upper_bound: no element smaller than {0} in this BitSet\n".format(x))

    def lower_bound(self, x, equal=False):
        """ return minimum element greater than x (or equal to x if equal=True) """
        y = self.code.get(x)
        if y is None:
            y = bisect_left(self.data, x)
        else:
            y += 1 - equal
        k = self.p.get(y) + 1
        if k <= self.size:
            return self.minimum(k)
        sys.stderr.write("lower_bound: no element larger than {0} in this BitSet\n".format(x))

    def nearest(self, x, k):
        """
        return k-th nearest value to x

        NOTE(review): the radius search below steps over integers in
        [0, 10**18], so this presumably expects integer values - confirm
        before using it with floats.  The selection logic is kept exactly
        as originally written.
        """
        if k > len(self):
            sys.stderr.write("nearest: k (= {0}) is larger than the size of this BitSet\n".format(k))
            return

        def test(d):
            # True while at most k elements lie within distance d of x.
            r = self.bisect_right(x + d) - 1
            l = self.bisect_left(x - d)
            return r - l + 1 <= k

        # Binary search for the largest radius that still passes test().
        ok, ng = 0, 10**18 + 1
        while abs(ok - ng) > 1:
            mid = (ok + ng) // 2
            if test(mid):
                ok = mid
            else:
                ng = mid
        d = ok
        r = self.bisect_right(x + d) - 1
        l = self.bisect_left(x - d)
        if d == 0:
            R = self.lower_bound(x, equal=True)
            L = self.upper_bound(x, equal=True)
            if abs(x - L) == abs(R - x):
                if self.count(L) >= k:
                    return L
                else:
                    return R
            elif abs(x - L) < abs(R - x):
                return L
            else:
                return R
        elif r - l + 1 == k:
            R = self[r]
            L = self[l]
            if abs(x - L) <= abs(R - x):
                return R
            else:
                return L
        else:
            if l <= 0:
                return self[r + 1]
            elif r >= len(self) - 1:
                return self[l - 1]
            else:
                R = self[r + 1]
                L = self[l - 1]
                if abs(x - L) == abs(R - x):
                    if self.count(L) >= k - (r - l + 1):
                        return L
                    else:
                        return R
                elif abs(x - L) < abs(R - x):
                    return L
                else:
                    return R

    def __getitem__(self, k):
        """
        return k-th minimum element (0-indexed)
        B[k] = sorted(A)[k]
        """
        if len(self) == 0:
            sys.stderr.write("__getitem__: no elements exist in this BitSet\n")
        elif k >= len(self):
            sys.stderr.write("__getitem__: index (={0}) is larger than the maximum index (={1})\n".format(k, len(self) - 1))
        elif k >= 0:
            return self.minimum(k + 1)
        else:
            sys.stderr.write("__getitem__: index (={0}) is negative \n".format(k))

    def __len__(self):
        return self.size

    def __iter__(self):
        """ return sorted list """
        for i in range(self.n + 1):
            # Query the multiplicity once per slot (each query is a Bit lookup).
            for _ in range(self.p[i]):
                yield self.decode[i]

    def __str__(self):
        """ return sorted list """
        text1 = " ".join(list(map(str, self)))
        return "[" + text1 + "]"


###################################################
def example():
    """ debugging aid: rebind the global ``input`` to replay a sample case """
    global input
    sample = """
6
10 5 1 5 10 7
"""
    lines = iter(sample.strip().split("\n"))
    input = lambda: next(lines)
###################################################


def solve(A):
    """Return the sum of A[i] + A[j] + A[k] over all i < j < k with
    A[i] > A[j] > A[k], modulo MOD.

    Each element is treated as the middle of the triple: with l larger
    elements on its left (sum Sl) and r smaller elements on its right
    (sum Sr) it contributes a*l*r + Sl*r + Sr*l.

    NOTE(review): the right-hand sum is taken over the value range [0, a),
    so this assumes non-negative integers - confirm against the constraints.
    """
    if not A:
        return 0
    left = SortedList2(A)        # elements already passed (starts empty)
    right = SortedList2(A, A)    # elements not yet passed (starts full)
    upper = max(A) + 1
    res = 0
    for a in A:
        right.remove(a)
        l = len(left) - left.bisect_right(a)
        Sl = left.get_range(a + 1, upper)
        r = right.bisect_left(a)
        Sr = right.get_range(0, a)
        res = (res + a * l * r + Sl * r + Sr * l) % MOD
        left.add(a)
    return res


def main():
    """ read N and A from ``input`` and print the answer """
    int(input())  # N is given but the length of A is all that is needed
    A = list(map(int, input().split()))
    print(solve(A))


if __name__ == "__main__":
    main()