結果
| 問題 | No.1300 Sum of Inversions |
| コンテスト | |
| ユーザー | |
| 提出日時 | 2021-04-30 01:02:14 |
| 言語 | PyPy3 (7.3.15) |
| 結果 | TLE |
| 実行時間 | - |
| コード長 | 25,382 bytes |
| コンパイル時間 | 166 ms |
| コンパイル使用メモリ | 82,248 KB |
| 実行使用メモリ | 253,744 KB |
| 最終ジャッジ日時 | 2024-07-18 01:09:25 |
| 合計ジャッジ時間 | 66,712 ms |
| ジャッジサーバーID (参考情報) | judge4 / judge5 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 3 |
| other | AC * 20 TLE * 14 |
ソースコード
class Bit:
    """Fenwick tree (binary indexed tree) over a 0-indexed array of numbers.

    Supports point updates and prefix/range sums. ``lower_bound`` additionally
    assumes that every stored element is non-negative.
    """

    def __init__(self, n, array=()):
        """
        :param n: number of elements
        :param array: optional initial values; ``array[i]`` is added to element i
        """
        self.n = n
        self.tree = [0] * (n + 1)
        # Exponent of the largest power-of-two step tried by lower_bound.
        self.depth = n.bit_length() - 1
        self.build(array)

    def get(self, i):
        """ return summation of elements in [0,i) """
        s = 0
        i -= 1
        while i >= 0:
            s += self.tree[i]
            i = (i & (i + 1)) - 1
        return s

    def build(self, array):
        """ build BIT from array (values are added on top of the current ones) """
        for i, a in enumerate(array):
            self.add(i, a)

    def add(self, i, x):
        """ add x to i-th element

        :raises IndexError: if i is negative (the update walk would never end)
        """
        if i < 0:
            raise IndexError("Bit.add: negative index {0}".format(i))
        while i < self.n:
            self.tree[i] += x
            i |= i + 1

    def get_range(self, i, j):
        """ return summation of elements in [i,j) """
        # get(0) is 0, so no special case is needed for i == 0.
        return self.get(j) - self.get(i)

    def lower_bound(self, x, equal=False):
        """
        return tuple = (maximum i s.t. a0+a1+...+ai < x (if not existing, -1), a0+a1+...+ai)
        if one wants to include equality (i.e., a0+a1+...+ai <= x), set equal = True
        (Caution) every A_i must be >= 0
        """
        sum_ = 0
        pos = -1  # would be 0 for a 1-indexed tree
        for i in range(self.depth, -1, -1):
            k = pos + (1 << i)
            if k >= self.n:  # would be k > self.n for a 1-indexed tree
                continue
            cand = sum_ + self.tree[k]
            if cand < x or (equal and cand == x):
                sum_ = cand
                pos = k
        return pos, sum_

    def __getitem__(self, i):
        """ [a0, a1, a2, ...]; negative i counts from the end """
        if i < 0:
            i = self.n + i
        return self.get_range(i, i + 1)

    def __setitem__(self, i, x):
        """ overwrite the i-th element with x; negative i counts from the end """
        if i < 0:
            i = self.n + i
        self.add(i, x - self[i])

    def __iter__(self):
        """ [a0, a1, a2, ...] """
        for i in range(self.n):
            yield self.get_range(i, i + 1)

    def __str__(self):
        text1 = " ".join(["element: "] + list(map(str, self)))
        text2 = " ".join(["cumsum(1-indexed): "] + [str(self.get(i)) for i in range(1, self.n + 1)])
        return "\n".join((text1, text2))
class BitImos:
    """Range-add / range-sum array built from two Fenwick trees.

    ``q`` stores the difference array d and ``p`` stores ``-j * d[j]``, so the
    sum of the first t elements equals ``p.get(t) + q.get(t) * t``.
    """

    def __init__(self, n, array=()):
        """
        :param n: number of elements
        :param array: optional initial values; ``array[i]`` is added to element i
        """
        self.n = n
        # One extra slot so that the right end t == n is a valid update index.
        self.p = Bit(self.n + 1)
        self.q = Bit(self.n + 1)
        self.build(array)

    def add(self, s, x):
        """ add x to the s-th element """
        self.add_range(s, s + 1, x)

    def add_range(self, s, t, x):
        """ add x to every element of the half-open interval [s,t) """
        self.p.add(s, -x * s)
        self.p.add(t, x * t)
        self.q.add(s, x)
        self.q.add(t, -x)

    def build(self, array):
        """ build from array (values are added on top of the current ones) """
        for i, a in enumerate(array):
            self.add(i, a)

    def get(self, s):
        """ return the s-th element """
        return self.get_range(s, s + 1)

    def get_range(self, s, t):
        """ return summation of elements in [s,t) """
        upto_t = self.p.get(t) + self.q.get(t) * t
        upto_s = self.p.get(s) + self.q.get(s) * s
        return upto_t - upto_s

    def __getitem__(self, s):
        """ return s-th element of array (not sum-array) """
        return self.q.get(s + 1)

    def __setitem__(self, i, x):
        """ overwrite the i-th element with x """
        self.add(i, x - self[i])

    def __iter__(self):
        """ yield every element; max(self) is what the imos method would give """
        for t in range(self.n):
            yield self.q.get(t + 1)

    def __str__(self):
        return " ".join(["element: "] + list(map(str, self)))
class SortedList:
    """Sorted multiset of integers in [0, n], backed by a Fenwick tree of counts.

    Query methods that cannot answer (index out of range, no such element)
    write a message to stderr and return None instead of raising.
    """

    def __init__(self, n, A=()):
        """
        :param n: maximum value that can be stored
        :param A: optional initial elements
        self.size: number of elements currently stored
        self.flip: running count, over all add() calls, of stored elements
                   greater than the value being added
        """
        self.n = n
        self.p = Bit(self.n + 1)
        self.size = 0
        self.flip = 0
        for a in A:
            self.add(a)

    def add(self, x):
        """ insert one copy of x """
        self.p.add(x, 1)
        self.size += 1
        # Can be dropped when flip_counter() is not needed.
        self.flip += self.size - self.p.get(x + 1)

    def remove(self, x):
        """ delete one copy of x (the caller must make sure it is present) """
        self.p.add(x, -1)
        self.size -= 1

    def bisect_left(self, x):
        """ return bisect_left(sorted(B),x) """
        if x <= self.n:
            return self.p.get(x)
        return self.size

    def bisect_right(self, x):
        """ return bisect_right(sorted(B),x) """
        x += 1
        if x <= self.n:
            return self.p.get(x)
        return self.size

    def flip_counter(self):
        """ return the value accumulated in self.flip by add() """
        return self.flip

    def count(self, x):
        """ return the multiplicity of x """
        return self.p[x]

    def count_range(self, l, r):
        """ return number of elements in the half-open interval [l,r) """
        return self.bisect_left(r) - self.bisect_left(l)

    def minimum(self, k=1):
        """ return k-th minimum value (1-indexed) """
        if k <= self.size:
            return self.p.lower_bound(k)[0] + 1
        sys.stderr.write("minimum: list index out of range (k={0})\n".format(k))

    def min(self):
        """ return the smallest element """
        return self.minimum(1)

    def max(self):
        """ return the largest element """
        return self.p.lower_bound(self.size)[0] + 1

    def upper_bound(self, x, equal=False):
        """ return maximum element lower than x (or equal to x if equal=True) """
        k = self.p.get(x + equal)
        if k:
            return self.minimum(k)
        sys.stderr.write("upper_bound: no element smaller than {0} in this BitSet\n".format(x))

    def lower_bound(self, x, equal=False):
        """ return minimum element greater than x (or equal to x if equal=True) """
        k = self.p.get(x + 1 - equal) + 1
        if k <= self.size:
            return self.minimum(k)
        sys.stderr.write("lower_bound: no element larger than {0} in this BitSet\n".format(x))

    def __getitem__(self, k):
        """
        return k-th minimum element (0-indexed)
        B[k] = sorted(A)[k]
        """
        if len(self) == 0:
            sys.stderr.write("__getitem__: no elements exist in this BitSet\n")
        elif k >= len(self):
            sys.stderr.write("__getitem__: index (={0}) is larger than the maximum index (={1})\n".format(k,len(self)-1))
        elif k >= 0:
            return self.minimum(k + 1)
        else:
            sys.stderr.write("__getitem__: index (={0}) is negative \n".format(k))

    def __len__(self):
        return self.size

    def __iter__(self):
        """ yield the elements in sorted order, with multiplicity """
        for i in range(self.n + 1):
            # Each p[i] costs two prefix sums, so look it up once per value.
            for _ in range(self.p[i]):
                yield i

    def __str__(self):
        """ return sorted list """
        return "[" + " ".join(map(str, self)) + "]"
class SortedList2:
    """Sorted multiset with coordinate compression.

    Only values that appear in ``data`` may be added or removed. Query methods
    that cannot answer write a message to stderr and return None.
    """

    def __init__(self, data, A=()):
        """
        :param data: every value that may ever be stored (duplicates allowed)
        :param A: optional initial elements
        self.size: number of elements currently stored
        self.flip: running count, over all add() calls, of stored elements
                   greater than the value being added
        """
        self.data = sorted(set(data))
        self.n = len(self.data)
        self.p = Bit(self.n + 1)
        self.size = 0
        self.flip = 0
        self.code = {b: i for i, b in enumerate(self.data)}  # value -> rank
        self.decode = list(self.data)                        # rank -> value
        for a in A:
            self.add(a)

    def add(self, x):
        """ insert one copy of x (x must be one of the values in data) """
        k = self.code[x]
        self.p.add(k, 1)
        self.size += 1
        # Can be dropped when the flip count is not needed.
        self.flip += self.size - self.p.get(k + 1)

    def remove(self, x):
        """ delete one copy of x (the caller must make sure it is present) """
        self.p.add(self.code[x], -1)
        self.size -= 1

    def bisect_left(self, x):
        """ return bisect_left(sorted(B),x) """
        if x in self.code:
            return self.p.get(self.code[x])
        return self.p.get(bisect_right(self.data, x))

    def bisect_right(self, x):
        """ return bisect_right(sorted(B),x)

        Computed as the number of stored elements below x + 1, so it
        presumes integer values.
        """
        x += 1
        if x in self.code:
            return self.p.get(self.code[x])
        return self.p.get(bisect_right(self.data, x))

    def count(self, x):
        """ return the multiplicity of x """
        return self.p[self.code[x]]

    def count_range(self, l, r):
        """ return number of elements in the half-open interval [l,r) """
        return self.bisect_left(r) - self.bisect_left(l)

    def minimum(self, k=1):
        """ return k-th minimum value (1-indexed) """
        if k <= self.size:
            return self.decode[self.p.lower_bound(k)[0] + 1]
        sys.stderr.write("minimum: list index out of range (k={0})\n".format(k))

    def min(self):
        """ return the smallest element """
        return self.minimum(1)

    def max(self):
        """ return the largest element """
        return self.decode[self.p.lower_bound(self.size)[0] + 1]

    def upper_bound(self, x, equal=False):
        """ return maximum element lower than x (or equal to x if equal=True) """
        if x in self.code:
            y = self.code[x] + equal
        else:
            y = bisect_right(self.data, x)
        k = self.p.get(y)
        if k:
            return self.minimum(k)
        sys.stderr.write("upper_bound: no element smaller than {0} in this BitSet\n".format(x))

    def lower_bound(self, x, equal=False):
        """ return minimum element greater than x (or equal to x if equal=True) """
        if x in self.code:
            y = self.code[x] + 1 - equal
        else:
            y = bisect_left(self.data, x)
        k = self.p.get(y) + 1
        if k <= self.size:
            return self.minimum(k)
        sys.stderr.write("lower_bound: no element larger than {0} in this BitSet\n".format(x))

    def nearest(self, x, k):
        """ return k-th nearest value to x """
        if k > len(self):
            sys.stderr.write("nearest: k (= {0}) is larger than the size of this BitSet\n".format(k))
            return

        def test(d):
            # True while at most k stored elements lie within distance d of x.
            r = self.bisect_right(x + d) - 1
            l = self.bisect_left(x - d)
            return r - l + 1 <= k

        # Binary search for the largest radius that still satisfies test().
        ok, ng = 0, 10**18 + 1
        while abs(ok - ng) > 1:
            mid = (ok + ng) // 2
            if test(mid):
                ok = mid
            else:
                ng = mid
        d = ok
        r = self.bisect_right(x + d) - 1
        l = self.bisect_left(x - d)
        # NOTE(review): lower_bound / upper_bound / self[...] return None when
        # no such element exists; the comparisons below presume they succeed.
        if d == 0:
            R = self.lower_bound(x, equal=True)
            L = self.upper_bound(x, equal=True)
            if abs(x - L) == abs(R - x):
                if self.count(L) >= k:
                    return L
                return R
            if abs(x - L) < abs(R - x):
                return L
            return R
        if r - l + 1 == k:
            R = self[r]
            L = self[l]
            if abs(x - L) <= abs(R - x):
                return R
            return L
        if l <= 0:
            return self[r + 1]
        if r >= len(self) - 1:
            return self[l - 1]
        R = self[r + 1]
        L = self[l - 1]
        if abs(x - L) == abs(R - x):
            if self.count(L) >= k - (r - l + 1):
                return L
            return R
        if abs(x - L) < abs(R - x):
            return L
        return R

    def __getitem__(self, k):
        """
        return k-th minimum element (0-indexed)
        B[k] = sorted(A)[k]
        """
        if len(self) == 0:
            sys.stderr.write("__getitem__: no elements exist in this BitSet\n")
        elif k >= len(self):
            sys.stderr.write("__getitem__: index (={0}) is larger than the maximum index (={1})\n".format(k,len(self)-1))
        elif k >= 0:
            return self.minimum(k + 1)
        else:
            sys.stderr.write("__getitem__: index (={0}) is negative \n".format(k))

    def __len__(self):
        return self.size

    def __iter__(self):
        """ yield the elements in sorted order, with multiplicity """
        for i in range(self.n + 1):
            # Rank n is the spare slot and is always empty, so decode[i] is safe.
            for _ in range(self.p[i]):
                yield self.decode[i]

    def __str__(self):
        """ return sorted list """
        return "[" + " ".join(map(str, self)) + "]"
class LazySegmentTree():
    """Non-recursive lazy segment tree (monoid values, composable range actions).

    The current value of node i is ``g(lazy[i], tree[i])`` once every action
    still stored on its proper ancestors has been pushed down.
    """

    def __init__(self, n, f, g, merge, ef, eh):
        """
        :param n: number of elements
        :param f: f(x, y) -> product of two values
        :param g: g(action, value) -> value after applying the action
        :param merge: merge(new_action, old_action) -> composed action
        :param ef: identity of f
        :param eh: identity action
        """
        self.n = n
        self.f = f
        # Skip g entirely for the identity action.
        self.g = lambda xh, x: g(xh, x) if xh != eh else x
        # self.g = g  # use the raw g instead when speed matters
        self.merge = merge
        self.ef = ef
        self.eh = eh
        l = (self.n - 1).bit_length()
        self.size = 1 << l
        self.tree = [self.ef] * (self.size << 1)
        self.lazy = [self.eh] * (self.size << 1)
        self.plt_cnt = 0  # sequence number for the images written by debug()

    def _value(self, i):
        """ value of node i, valid once its ancestors have been pushed """
        return self.g(self.lazy[i], self.tree[i])

    def _push(self, x):
        """ move the pending action of internal node x onto its two children """
        laz = self.lazy[x]
        if laz == self.eh:
            return
        self.lazy[(x << 1) | 1] = self.merge(laz, self.lazy[(x << 1) | 1])
        self.lazy[x << 1] = self.merge(laz, self.lazy[x << 1])
        self.tree[x] = self.g(laz, self.tree[x])
        self.lazy[x] = self.eh

    def build(self, array):
        """
        build seg tree from array
        """
        for i in range(self.n):
            self.tree[self.size + i] = array[i]
        for i in range(self.size - 1, 0, -1):
            self.tree[i] = self.f(self.tree[i << 1], self.tree[(i << 1) | 1])

    def replace(self, i, x):
        """
        update (=replace) st[i] by x
        NOT update_range(i,i+1)
        """
        i += self.size
        self.propagate_lazy(i)
        self.tree[i] = x
        self.lazy[i] = self.eh
        self.propagate_tree(i)

    def get(self, i):
        """ return the current value of the i-th element """
        i += self.size
        self.propagate_lazy(i)
        return self.g(self.lazy[i], self.tree[i])

    def update_range(self, l, r, x):
        """
        act op(x, a) on elements a in [l, r) ( 0-indexed ) ( O(logN) )
        """
        l += self.size
        r += self.size
        # Topmost nodes whose ranges start at l / end at r.
        l0 = l // (l & -l)
        r0 = r // (r & -r)
        self.propagate_lazy(l0)
        self.propagate_lazy(r0 - 1)
        while l < r:
            if r & 1:
                r -= 1
                self.lazy[r] = self.merge(x, self.lazy[r])
            if l & 1:
                self.lazy[l] = self.merge(x, self.lazy[l])
                l += 1
            l >>= 1
            r >>= 1
        self.propagate_tree(l0)
        self.propagate_tree(r0 - 1)

    def update_range_right_half(self, l, x):
        """
        act op(x, a) on elements a in [l, N) ( 0-indexed ) ( O(logN) )
        """
        if l == 0:
            self.update_all(x)
            return
        l += self.size
        l0 = l // (l & -l)
        self.propagate_lazy(l0)
        while l > 1:
            if l & 1:
                self.lazy[l] = self.merge(x, self.lazy[l])
                l += 1
            l >>= 1
        self.propagate_tree(l0)

    def update_range_left_half(self, r, x):
        """
        act op(x, a) on elements a in [0, r) ( 0-indexed ) ( O(logN) )
        """
        if r == self.n:
            self.update_all(x)
            return
        r += self.size
        r0 = r // (r & -r)
        self.propagate_lazy(r0 - 1)
        while r > 1:
            if r & 1:
                r -= 1
                self.lazy[r] = self.merge(x, self.lazy[r])
            r >>= 1
        self.propagate_tree(r0 - 1)

    def update_all(self, x):
        """ act op(x, a) on every element """
        self.lazy[1] = self.merge(x, self.lazy[1])

    def get_range(self, l, r):
        """
        get value from [l, r) (0-indexed)
        """
        l += self.size
        r += self.size
        self.propagate_lazy(l // (l & -l))
        self.propagate_lazy((r // (r & -r)) - 1)
        res_l = res_r = self.ef
        while l < r:
            if l & 1:
                res_l = self.f(res_l, self.g(self.lazy[l], self.tree[l]))
                l += 1
            if r & 1:
                r -= 1
                res_r = self.f(self.g(self.lazy[r], self.tree[r]), res_r)
            l >>= 1
            r >>= 1
        return self.f(res_l, res_r)

    def get_range_left_half(self, r):
        """
        get value from [0, r) (0-indexed)
        """
        if r == self.n:
            return self.get_all()
        r += self.size
        self.propagate_lazy((r // (r & -r)) - 1)
        res_l = res_r = self.ef
        while r > 1:
            if r & 1:
                r -= 1
                res_r = self.f(self.g(self.lazy[r], self.tree[r]), res_r)
            r >>= 1
        return self.f(res_l, res_r)

    def get_range_right_half(self, l):
        """
        get value from [l, N) (0-indexed)
        """
        if l == 0:
            return self.get_all()
        l += self.size
        self.propagate_lazy(l // (l & -l))
        res_l = res_r = self.ef
        while l > 1:
            if l & 1:
                res_l = self.f(res_l, self.g(self.lazy[l], self.tree[l]))
                l += 1
            l >>= 1
        return self.f(res_l, res_r)

    def get_all(self):
        """ get value from the whole array """
        return self.g(self.lazy[1], self.tree[1])

    def max_right(self, l, func):
        """
        return r such that
        - r = l or func(op(a[l], a[l + 1], ..., a[r - 1])) = true
        - r = n or func(op(a[l], a[l + 1], ..., a[r])) = false
        """
        if l >= self.n:
            return self.n
        l += self.size
        # Every node visited while climbing has all its ancestors on the
        # root-to-leaf path of l, so pushing that path makes _value() exact.
        self.propagate_lazy(l)
        s = self.ef
        while True:
            while l % 2 == 0:
                l >>= 1
            if not func(self.f(s, self._value(l))):
                while l < self.size:
                    self._push(l)  # children must see this node's action
                    l <<= 1
                    cand = self.f(s, self._value(l))
                    if func(cand):
                        s = cand
                        l += 1
                return l - self.size
            s = self.f(s, self._value(l))
            l += 1
            if l & -l == l:
                break
        return self.n

    def min_left(self, r, func):
        """
        return l such that
        - l = r or func(op(a[l], a[l + 1], ..., a[r - 1])) = true
        - l = 0 or func(op(a[l - 1], a[l], ..., a[r - 1])) = false
        """
        if r <= 0:
            return 0
        r += self.size
        # Same reasoning as max_right, for the path of the last leaf r - 1.
        self.propagate_lazy(r - 1)
        s = self.ef
        while True:
            r -= 1
            while r > 1 and r % 2:
                r >>= 1
            if not func(self.f(self._value(r), s)):
                while r < self.size:
                    self._push(r)  # children must see this node's action
                    r = (r << 1) | 1
                    cand = self.f(self._value(r), s)
                    if func(cand):
                        s = cand
                        r -= 1
                return r + 1 - self.size
            s = self.f(self._value(r), s)
            if r & -r == r:
                break
        return 0

    def propagate_lazy(self, i):
        """ push pending actions down from the root through every proper ancestor of node i """
        for k in range(i.bit_length() - 1, 0, -1):
            self._push(i >> k)

    def propagate_tree(self, i):
        """ recompute tree[] on every proper ancestor of node i, bottom-up """
        for _ in range(1, i.bit_length()):
            i >>= 1
            self.tree[i] = self.f(self.g(self.lazy[i<<1], self.tree[i<<1]), self.g(self.lazy[(i<<1)|1], self.tree[(i<<1)|1]))

    def __getitem__(self, i):
        if i < 0:
            i = self.n + i
        return self.get(i)

    def __setitem__(self, i, value):
        if i < 0:
            i = self.n + i
        self.replace(i, value)

    def __iter__(self):
        """ push every pending action down to the leaves, then yield the n elements """
        for x in range(1, self.size):
            self._push(x)
        for xh, x in zip(self.lazy[self.size:self.size+self.n], self.tree[self.size:self.size+self.n]):
            yield self.g(xh, x)

    def __str__(self):
        return str(list(self))

    def debug(self):
        """ draw tree[] and lazy[] as tree-<k>.png / lazy-<k>.png (needs networkx and matplotlib) """
        def full_tree_pos(G):
            # Lay the nodes 0..n-1 out as a complete binary tree and add the edges.
            n = G.number_of_nodes()
            if n == 0: return {}
            pos = {0: (0.5, 0.9)}
            if n == 1: return pos
            i = 1
            while not n >= 2 ** i or not n < 2 ** (i + 1): i+=1
            height = i
            p_key, p_y, p_x = 0, 0.9, 0.5
            l_child = True
            for i in range(height):
                for j in range(2 ** (i + 1)):
                    if 2 ** (i + 1) + j - 1 < n:
                        if l_child == True:
                            pos[2 ** (i + 1) + j - 1] = (p_x - 0.2 / (i * i + 1), p_y - 0.1)
                            G.add_edge(2 ** (i + 1) + j - 1, p_key)
                            l_child = False
                        else:
                            pos[2 ** (i + 1) + j - 1] = (p_x + 0.2 / (i * i + 1), p_y - 0.1)
                            l_child = True
                            G.add_edge(2 ** (i + 1) + j - 1, p_key)
                            # Right child placed: move on to the next parent.
                            p_key += 1
                            (p_x, p_y) = pos[p_key]
            return pos
        import networkx as nx
        import matplotlib.pyplot as plt
        A = self.tree[1:]
        G = nx.Graph()
        labels = {}
        for i, a in enumerate(A):
            G.add_node(i)
            labels[i] = a
        pos = full_tree_pos(G)
        nx.draw(G, pos=pos, with_labels=True, labels=labels, node_size=1000)
        plt.savefig("tree-{0}.png".format(self.plt_cnt))
        plt.clf()
        A = self.lazy[1:-1]
        G = nx.Graph()
        labels = {}
        for i, a in enumerate(A):
            G.add_node(i)
            labels[i] = a
        pos = full_tree_pos(G)
        nx.draw(G, pos=pos, with_labels=True, labels=labels, node_size=1000)
        plt.savefig("lazy-{0}.png".format(self.plt_cnt))
        plt.clf()
        self.plt_cnt += 1
class CompressSegment:
    """Coordinate compression for (coordinate, payload) pairs.

    After construction ``X`` holds the coordinates in ascending order, ``A``
    the payloads in the same order, and ``Xc`` maps a coordinate to its index
    (the last one when a coordinate repeats).
    """

    def __init__(self, data):
        """
        :param data: list of (x, h) pairs; it is sorted in place by x
        """
        data.sort(key=lambda x: x[0])
        self.X, self.A, self.Xc = [], [], dict()
        for i, (x, h) in enumerate(data):
            self.X.append(x)
            self.A.append(h)
            self.Xc[x] = i

    def __call__(self, l, r):
        """ same as range(l, r) """
        return self.range(l, r)

    def range(self, l, r):
        """ map the coordinate interval [l, r) to the matching index interval """
        return bisect_left(self.X, l), bisect_left(self.X, r)

    def __getitem__(self, i):
        """ return the payload at compressed index i """
        return self.A[i]

    def __iter__(self):
        """ yield the payloads in coordinate order """
        for a in self.A:
            yield a
#############################################################
import sys
import os
from io import BytesIO, IOBase
BUFSIZE = 8192  # minimum number of bytes requested per os.read call


class FastIO(IOBase):
    """Byte stream on a raw file descriptor, buffered through a BytesIO.

    Input is pulled with os.read in large chunks; output is collected in
    memory and written to the descriptor only by flush().
    """

    newlines = 0  # newline-terminated lines buffered but not yet returned

    def __init__(self, file):
        """
        :param file: open file object providing fileno() and mode
        """
        self._fd = file.fileno()
        self.buffer = BytesIO()
        self.writable = "x" in file.mode or "r" not in file.mode
        self.write = self.buffer.write if self.writable else None

    def _append(self, chunk):
        """ append chunk at the end of the buffer, keeping the read position """
        ptr = self.buffer.tell()
        self.buffer.seek(0, 2)
        self.buffer.write(chunk)
        self.buffer.seek(ptr)

    def read(self):
        """ read until os.read returns nothing; return all unread bytes """
        while True:
            b = os.read(self._fd, max(os.fstat(self._fd).st_size, BUFSIZE))
            if not b:
                break
            self._append(b)
        self.newlines = 0
        return self.buffer.read()

    def readline(self):
        """ return the next line as bytes (b"" once the input is exhausted) """
        while self.newlines == 0:
            b = os.read(self._fd, max(os.fstat(self._fd).st_size, BUFSIZE))
            # An empty read counts as one pending "line" so the loop ends.
            self.newlines = b.count(b"\n") + (not b)
            self._append(b)
        self.newlines -= 1
        return self.buffer.readline()

    def flush(self):
        """ write the buffered output to the descriptor and clear the buffer """
        if self.writable:
            os.write(self._fd, self.buffer.getvalue())
            self.buffer.truncate(0)
            self.buffer.seek(0)
class IOWrapper(IOBase):
    """Text (ASCII) front end for FastIO: str in and out, bytes underneath."""

    def __init__(self, file):
        """
        :param file: open file object, handed to FastIO
        """
        self.buffer = FastIO(file)
        self.flush = self.buffer.flush
        self.writable = self.buffer.writable
        self.write = lambda s: self.buffer.write(s.encode("ascii"))
        self.read = lambda: self.buffer.read().decode("ascii")
        self.readline = lambda: self.buffer.readline().decode("ascii")
# Replace sys.stdin with the buffered wrapper and make input() read lines from it.
sys.stdin = IOWrapper(sys.stdin)
input = sys.stdin.readline
###############################################################################
def example():
    """Rebind the global ``input`` so that it replays the embedded sample.

    Each call of the new ``input`` returns the next line of the sample (without
    a trailing newline) and raises StopIteration once the lines are used up.
    """
    global input
    sample_lines = iter(
        """
9
0 1 2 3 4 5 6 7 8
"""
        .strip().split("\n"))
    input = lambda: next(sample_lines)
###############################################################################
import sys
input = sys.stdin.readline
from bisect import bisect_left, bisect_right
# example()
MOD=998244353
N=int(input())
A=list(map(int, input().split()))
# L holds the elements to the left of position i, R those to its right.
L = SortedList2(A)
R = SortedList2(A,A)
res=0
for i in range(N):
    a=A[i]
    R.remove(a)
    l=len(L)-L.bisect_right(a)  # elements on the left that are greater than a
    r=R.bisect_left(a)          # elements on the right that are smaller than a
    # a is the middle element of l*r strictly decreasing triples.
    res+=l*r*a
    res%=MOD
    L.add(a)
####### update: add / query: sum ###############################################
# A tree value packs two fields into one int: the bits above `shift` hold a
# value reduced mod MOD, the low `shift` bits hold a plain integer.
shift=20
mask=(1<<shift)-1
# get chain rule
def f(x,y):
    """Combine two packed values field by field.

    The high fields are added modulo MOD; the low fields are added as they are,
    so their total is presumed to stay below ``1 << shift``.
    """
    x0, x1 = x>>shift, x&mask
    y0, y1 = y>>shift, y&mask
    z0 = (x0+y0)%MOD
    z1 = x1+y1
    return (z0<<shift)|z1
ef = 0
# merge of update
def merge(a,b):
    """Compose two pending additive updates (sum modulo MOD)."""
    return (a+b)%MOD
eh = 0
# update chain rule
def g(a,x):
    """Apply the pending update ``a`` to the packed value ``x``.

    The update adds ``a`` to each of the ``x1`` elements counted in the low
    field, so the high field grows by ``a * x1`` (mod MOD). The previous high
    field has to be kept: dropping it would turn the addition into an
    overwrite, which contradicts the additive ``merge``.
    """
    x0, x1 = x>>shift, x&mask
    return (((x0+a*x1)%MOD)<<shift)|x1
# (how to get)
# res, _ = divmod(st.get_range(l, r), 1 << shift)
################################################################################
st = LazySegmentTree(N, f, g, merge, ef, eh)
st2 = LazySegmentTree(N, f, g, merge, ef, eh)


def _outer_contribution(seq, tree):
    """Return sum(a * large) mod MOD over seq, scanned from first to last.

    For each element a, ``large`` is the sum of the high fields stored at ranks
    above a's rank, read before a itself is recorded. Recording a adds 1 to the
    low field at its rank and adds the number of already-seen larger elements
    to the high field there.
    """
    # Equal values share one rank (the last index among them in sorted order).
    rank = {x: i for i, x in enumerate(sorted(seq))}
    unit = 1 << shift
    total = 0
    for a in seq:
        k = rank[a]
        # Ranks above k do not depend on leaf k, so query before writing it;
        # this needs one point write per element instead of two.
        large, size = divmod(tree.get_range_right_half(k + 1), unit)
        x, y = divmod(tree[k] + 1, unit)
        tree[k] = (((x + size) % MOD) << shift) + y
        total += a * large
    return total % MOD


# Forward scan on A: each A[k] weighted by the pairs i < j < k with A[i] > A[j] > A[k].
res = (res + _outer_contribution(A, st)) % MOD
# Backward scan on the negated values: the mirror-image term for the first
# element of each triple (subtracting a sum of negated values adds it).
res = (res - _outer_contribution([-a for a in reversed(A)], st2)) % MOD
print(res)