class Bit: def __init__(self, n, array=[]): """ :param n: number of elements """ self.n = n self.tree = [0]*(n+1) self.depth = n.bit_length() - 1 for i, a in enumerate(array): self.add(i, a) def get(self,i): """ return summation of elements in [0,i) """ s = 0 i -= 1 while i >= 0: s += self.tree[i] i = (i & (i + 1) )- 1 return s def build(self, array): """ bulid BIT from array """ for i, a in enumerate(array): self.add(i, a) def add(self, i, x): """ add x to i-th element """ while i < self.n: self.tree[i] += x i |= i + 1 def get_range(self,i,j): """ return summation of elements in [i,j) """ if i == 0: return self.get(j) return self.get(j)-self.get(i) def lower_bound(self, x, equal=False): """ return tuple = (return maximum i s.t. a0+a1+...+ai < x (if not existing, -1 ) , a0+a1+...+ai ) if one wants to include equal (i.e., a0+a1+...+ai <= x), please set equal = True (Cation) We must assume that A_i>=0 """ sum_ = 0 pos = -1 # 1-indexed の時は pos = 0 if not equal: for i in range(self.depth, -1, -1): k = pos + (1 << i) if k < self.n and sum_ + self.tree[k] < x: # 1-indexed の時は k <= self.n sum_ += self.tree[k] pos += 1 << i if equal: for i in range(self.depth, -1, -1): k = pos + (1 << i) if k < self.n and sum_ + self.tree[k] <= x: # 1-indexed の時は k <= self.n sum_ += self.tree[k] pos += 1 << i return pos, sum_ def __getitem__(self, i): """ [a0, a1, a2, ...] """ if i<0: i=self.n+i return self.get_range(i,i+1) def __setitem__(self, i, x): self.add(i,x-self[i]) def __iter__(self): """ [a0, a1, a2, ...] """ for i in range(self.n): yield self.get_range(i,i+1) def __str__(self): text1 = " ".join(["element: "] + list(map(str, self))) text2 = " ".join(["cumsum(1-indexed): "]+list(str(self.get(i)) for i in range(1,self.n+1))) return "\n".join((text1, text2)) class BitImos: def __init__(self, n, array=[]): self.n = n self.p = Bit(self.n + 1) self.q = Bit(self.n + 1) for i, a in enumerate(array): self.add(i, a) def add(self, s, x): self.add_range(s,s+1,x) def add_range(self, s, t, x): """ add x to a close-interval [s,t)""" self.p.add(s, -x * s) self.p.add(t, x * t) self.q.add(s, x) self.q.add(t, -x) def build(self, array): """ bulid BIT from array """ for i, a in enumerate(array): self.add(i, a) def get(self,s): return self.get_range(s,s+1) def get_range(self,s,t): """ return summation of elements in [s,t) """ return self.p.get(t)+self.q.get(t)*t-self.p.get(s)-self.q.get(s)*s def __getitem__(self, s): """ return s-th element of array (not sum-array) """ return self.q.get(s+1) def __setitem__(self, i, x): self.add(i,x-self[i]) def __iter__(self): """ max(self) returns what we obtain by the Imos method""" for t in range(self.n): yield self.q.get(t+1) def __str__(self): text1 = " ".join(["element: "] + list(map(str, self))) return text1 class SortedList: def __init__(self, n, A=[]): """ :param n: miximum value of A self.size: number of elements in BitSet """ self.n = n self.p = Bit(self.n + 1) self.size = 0 self.flip = 0 for a in A: self.add(a) def add(self,x): self.p.add(x, 1) self.size += 1 self.flip += self.size - self.p.get(x+1) # we can remove this if we do not use flip_number def remove(self,x): self.p.add(x, -1) self.size -= 1 def bisect_left(self,x): """ return bisect_left(sorted(B),x) """ if x <= self.n: return self.p.get(x) else: return self.size def bisect_right(self,x): """ return bisect_right(sorted(B),x) """ x += 1 if x <= self.n: return self.p.get(x) else: return self.size def flip_counter(self): return self.flip def count(self,x): return self.p[x] def count_range(self,l,r): """ return number of elements in open set [l,r)""" return self.bisect_left(r)-self.bisect_left(l) def minimum(self,k=1): """ return k-th minimum value """ if k <= self.size: return self.p.lower_bound(k)[0] + 1 else: sys.stderr.write("minimum: list index out of range (k={0})\n".format(k)) def min(self): return self.minimum(1) def max(self): return self.p.lower_bound(self.size)[0] + 1 def upper_bound(self,x,equal=False): """ return maximum element lower than x """ k = self.p.get(x+equal) if k: return self.minimum(k) else: sys.stderr.write("upper_bound: no element smaller than {0} in this BitSet\n".format(x)) def lower_bound(self,x,equal=False): """ return minimum element greater than x """ k =self.p.get(x+1-equal)+1 if k <= self.size: return self.minimum(k) else: sys.stderr.write("lower_bound: no element larger than {0} in this BitSet\n".format(x)) def __getitem__(self, k): """ return k-th minimum element (0-indexed) B[k] = sorted(A)[k] """ if len(self)==0: sys.stderr.write("__getitem__: no elements exist in this BitSet\n") elif k >= len(self): sys.stderr.write("__getitem__: index (={0}) is larger than the maximum index (={1})\n".format(k,len(self)-1)) elif k >= 0: return self.minimum(k+1) else: sys.stderr.write("__getitem__: index (={0}) is negative \n".format(k)) def __len__(self): return self.size def __iter__(self): """ return sorted list """ for i in range(self.n+1): if self.p[i]: for _ in range(self.p[i]): yield i def __str__(self): """ return sorted list """ text1 = " ".join(list(map(str, self))) return "[" + text1 + "]" class SortedList2: """ if we need compress """ def __init__(self, data, A=[]): """ self.size: number of elements in BitSet """ self.data = sorted(list(set(data))) self.n = len(self.data) self.p = Bit(self.n + 1) self.size = 0 self.flip = 0 self.code = {} self.decode = [] for i, b in enumerate(self.data): self.code[b] = i self.decode.append(b) for a in A: self.add(a) def add(self,x): self.p.add(self.code[x], 1) self.size += 1 self.flip += self.size - self.p.get(self.code[x]+1) # we can remove this if we do not use flip_number def remove(self,x): self.p.add(self.code[x], -1) self.size -= 1 def bisect_left(self,x): """ return bisect_left(sorted(B),x) """ if x in self.code.keys(): return self.p.get(self.code[x]) else: return self.p.get(bisect_right(self.data,x)) def bisect_right(self,x): """ return bisect_right(sorted(B),x) """ x += 1 if x in self.code.keys(): return self.p.get(self.code[x]) else: return self.p.get(bisect_right(self.data,x)) def count(self,x): return self.p[self.code[x]] def count_range(self,l,r): """ return number of elements in open set [l,r)""" return self.bisect_left(r)-self.bisect_left(l) def minimum(self,k=1): """ return k-th minimum value """ if k <= self.size: return self.decode[self.p.lower_bound(k)[0] + 1] else: sys.stderr.write("minimum: list index out of range (k={0})\n".format(k)) def min(self): return self.minimum(1) def max(self): return self.decode[self.p.lower_bound(self.size)[0] + 1] def upper_bound(self,x,equal=False): """ return maximum element lower than x """ if x in self.code.keys(): y = self.code[x] + equal else: y = bisect_right(self.data, x) k = self.p.get(y) if k: return self.minimum(k) else: sys.stderr.write("upper_bound: no element smaller than {0} in this BitSet\n".format(x)) def lower_bound(self,x,equal=False): """ return minimum element greater than x """ if x in self.code.keys(): y = self.code[x] + 1 - equal else: y = bisect_left(self.data, x) k =self.p.get(y)+1 if k <= self.size: return self.minimum(k) else: sys.stderr.write("lower_bound: no element larger than {0} in this BitSet\n".format(x)) def nearest(self,x,k): """ return k-th nearest value to x """ if k>len(self): sys.stderr.write("nearest: k (= {0}) is larger than the size of this BitSet\n".format(k)) return def test(d): r=self.bisect_right(x+d)-1 l=self.bisect_left(x-d) return r-l+1<=k ok,ng=0,10**18+1 while abs(ok-ng)>1: mid=(ok+ng)//2 if test(mid): ok=mid else: ng=mid d=ok r=self.bisect_right(x+d)-1 l=self.bisect_left(x-d) if d==0: R=self.lower_bound(x,equal=True) L=self.upper_bound(x,equal=True) if abs(x-L)==abs(R-x): if self.count(L)>=k: return L else: return R elif abs(x-L)=len(self)-1: return self[l-1] else: R=self[r+1] L=self[l-1] if abs(x-L)==abs(R-x): if self.count(L)>=k-(r-l+1): return L else: return R elif abs(x-L)= len(self): sys.stderr.write("__getitem__: index (={0}) is larger than the maximum index (={1})\n".format(k,len(self)-1)) elif k >= 0: return self.minimum(k+1) else: sys.stderr.write("__getitem__: index (={0}) is negative \n".format(k)) def __len__(self): return self.size def __iter__(self): """ return sorted list """ for i in range(self.n+1): if self.p[i]: for _ in range(self.p[i]): yield self.decode[i] def __str__(self): """ return sorted list """ text1 = " ".join(list(map(str, self))) return "[" + text1 + "]" class LazySegmentTree(): def __init__(self, n, f, g, merge, ef, eh): self.n = n self.f = f self.g = lambda xh, x: g(xh, x) if xh != eh else x # self.g = g # 高速化したい場合 self.merge = merge self.ef = ef self.eh = eh l = (self.n - 1).bit_length() self.size = 1 << l self.tree = [self.ef] * (self.size << 1) self.lazy = [self.eh] * (self.size << 1) self.plt_cnt = 0 def build(self, array): """ bulid seg tree from array """ for i in range(self.n): self.tree[self.size + i] = array[i] for i in range(self.size - 1, 0, -1): self.tree[i] = self.f(self.tree[i<<1], self.tree[(i<<1)|1]) def replace(self,i,x): """ update (=replace) st[i] by x NOT update_range(i,i+1) """ i += self.size self.propagate_lazy(i) self.tree[i] = x self.lazy[i] = self.eh self.propagate_tree(i) def get(self, i): i += self.size self.propagate_lazy(i) return self.g(self.lazy[i], self.tree[i]) def update_range(self, l, r, x): """ act op(x, a) on elements a in [l, r) ( 0-indexed ) ( O(logN) ) """ l += self.size r += self.size l0 = l//(l&-l) r0 = r//(r&-r) self.propagate_lazy(l0) self.propagate_lazy(r0-1) while l < r: if r&1: r -= 1 self.lazy[r] = self.merge(x, self.lazy[r]) if l&1: self.lazy[l] = self.merge(x, self.lazy[l]) l += 1 l >>= 1 r >>= 1 self.propagate_tree(l0) self.propagate_tree(r0-1) def update_range_right_half(self, l, x): """ act op(x, a) on elements a in [l, N) ( 0-indexed ) ( O(logN) ) """ if l==0: self.update_all(x) return l += self.size l0 = l//(l&-l) self.propagate_lazy(l0) while l>1: if l&1: self.lazy[l] = self.merge(x, self.lazy[l]) l += 1 l >>= 1 self.propagate_tree(l0) def update_range_left_half(self, r, x): """ act op(x, a) on elements a in [0, r) ( 0-indexed ) ( O(logN) ) """ if r==self.n: self.update_all(x) return r += self.size r0 = r//(r&-r) self.propagate_lazy(r0-1) while r>1: if r&1: r -= 1 self.lazy[r] = self.merge(x, self.lazy[r]) r >>= 1 self.propagate_tree(r0-1) def update_all(self, x): self.lazy[1]=self.merge(x,self.lazy[1]) def get_range(self, l, r): """ get value from [l, r) (0-indexed) """ l += self.size r += self.size self.propagate_lazy(l//(l&-l)) self.propagate_lazy((r//(r&-r))-1) res_l = res_r = self.ef while l < r: if l & 1: res_l = self.f(res_l, self.g(self.lazy[l], self.tree[l])) l += 1 if r & 1: r -= 1 res_r = self.f(self.g(self.lazy[r], self.tree[r]), res_r) l >>= 1 r >>= 1 return self.f(res_l, res_r) def get_range_left_half(self, r): """ get value from [0, r) (0-indexed) """ if r==self.n: return self.get_all() r += self.size self.propagate_lazy((r//(r&-r))-1) res_l = res_r = self.ef while r>1: if r & 1: r -= 1 res_r = self.f(self.g(self.lazy[r], self.tree[r]), res_r) r >>= 1 return self.f(res_l, res_r) def get_range_right_half(self, l): """ get value from [l, N) (0-indexed) """ if l==0: return self.get_all() l += self.size self.propagate_lazy(l//(l&-l)) res_l = res_r = self.ef while l>1: if l & 1: res_l = self.f(res_l, self.g(self.lazy[l], self.tree[l])) l += 1 l >>= 1 return self.f(res_l, res_r) def get_all(self): return self.g(self.lazy[1], self.tree[1]) def max_right(self,l,func): """ return r such that ・r = l or f(op(a[l], a[l + 1], ..., a[r - 1])) = true ・r = n or f(op(a[l], a[l + 1], ..., a[r])) = false """ if l >= self.n: return self.n l += self.size s = self.ef while 1: while l % 2 == 0: l >>= 1 if not func(self.f(s,self.g(self.lazy[l],self.tree[l]))): while l < self.size: l<<=1 if func(self.f(s,self.g(self.lazy[l],self.tree[l]))): s = self.f(s, self.g(self.lazy[l], self.tree[l])) l += 1 return l - self.size s = self.f(s, self.g(self.lazy[l], self.tree[l])) l += 1 if l & -l == l: break return self.n def min_left(self,r,func): """ return l such that ・l = r or f(op(a[l], a[l + 1], ..., a[r - 1])) = true ・l = 0 or f(op(a[l - 1], a[l], ..., a[r - 1])) = false """ if r <= 0: return 0 r += self.size s = self.ef while 1: r -= 1 while r > 1 and r % 2: r >>= 1 if not func(self.f(self.g(self.lazy[r],self.tree[r]),s)): while r < self.size: r = (r<<1)|1 if func(self.f(self.g(self.lazy[r],self.tree[r]),s)): s = self.f(self.g(self.lazy[r], self.tree[r]), s) r -= 1 return r + 1 - self.size s = self.f(self.g(self.lazy[r], self.tree[r]), s) if r & -r == r: break return 0 def propagate_lazy(self, i): for k in range(i.bit_length()-1,0,-1): x = i>>k laz = self.lazy[x] if laz == self.eh: continue self.lazy[(x<<1)|1] = self.merge(laz, self.lazy[(x << 1) | 1]) self.lazy[x<<1] = self.merge(laz, self.lazy[x << 1]) self.tree[x] = self.g(laz, self.tree[x]) self.lazy[x] = self.eh def propagate_tree(self, i): for _ in range(1,i.bit_length()): i>>=1 self.tree[i] = self.f(self.g(self.lazy[i<<1], self.tree[i<<1]), self.g(self.lazy[(i<<1)|1], self.tree[(i<<1)|1])) def __getitem__(self, i): if i<0: i=self.n+i return self.get(i) def __setitem__(self, i, value): if i<0: i=self.n+i self.replace(i,value) def __iter__(self): for x in range(1, self.size): if self.lazy[x] == self.eh: continue self.lazy[(x<<1)|1] = self.merge(self.lazy[x], self.lazy[(x << 1) | 1]) self.lazy[x<<1] = self.merge(self.lazy[x], self.lazy[x << 1]) self.tree[x] = self.g(self.lazy[x], self.tree[x]) self.lazy[x] = self.eh for xh, x in zip(self.lazy[self.size:self.size+self.n], self.tree[self.size:self.size+self.n]): yield self.g(xh,x) def __str__(self): return str(list(self)) def debug(self): def full_tree_pos(G): n = G.number_of_nodes() if n == 0: return {} pos = {0: (0.5, 0.9)} if n == 1: return pos i = 1 while not n >= 2 ** i or not n < 2 ** (i + 1): i+=1 height = i p_key, p_y, p_x = 0, 0.9, 0.5 l_child = True for i in range(height): for j in range(2 ** (i + 1)): if 2 ** (i + 1) + j - 1 < n: if l_child == True: pos[2 ** (i + 1) + j - 1] = (p_x - 0.2 / (i * i + 1), p_y - 0.1) G.add_edge(2 ** (i + 1) + j - 1, p_key) l_child = False else: pos[2 ** (i + 1) + j - 1] = (p_x + 0.2 / (i * i + 1), p_y - 0.1) l_child = True G.add_edge(2 ** (i + 1) + j - 1, p_key) p_key += 1 (p_x, p_y) = pos[p_key] return pos import networkx as nx import matplotlib.pyplot as plt A = self.tree[1:] G = nx.Graph() labels = {} for i, a in enumerate(A): G.add_node(i) labels[i] = a pos = full_tree_pos(G) nx.draw(G, pos=pos, with_labels=True, labels=labels, node_size=1000) plt.savefig("tree-{0}.png".format(self.plt_cnt)) plt.clf() A = self.lazy[1:-1] G = nx.Graph() labels = {} for i, a in enumerate(A): G.add_node(i) labels[i] = a pos = full_tree_pos(G) nx.draw(G, pos=pos, with_labels=True, labels=labels, node_size=1000) plt.savefig("lazy-{0}.png".format(self.plt_cnt)) plt.clf() self.plt_cnt += 1 class CompressSegment: def __init__(self,data): data.sort(key=lambda x:x[0]) self.X, self.A, self.Xc = [], [], dict() for i, d in enumerate(data): x, h = d self.X.append(x) self.A.append(h) self.Xc[x]=i def __call__(self, l, r): return bisect_left(self.X,l), bisect_left(self.X,r) def range(self, l, r): return bisect_left(self.X,l), bisect_left(self.X,r) def __getitem__(self, i): return self.A[i] def __iter__(self): for a in self.A: yield a ############################################################# import sys import os from io import BytesIO, IOBase BUFSIZE = 8192 class FastIO(IOBase): newlines = 0 def __init__(self, file): self._fd = file.fileno() self.buffer = BytesIO() self.writable = "x" in file.mode or "r" not in file.mode self.write = self.buffer.write if self.writable else None def read(self): while True: b = os.read(self._fd, max(os.fstat(self._fd).st_size, BUFSIZE)) if not b: break ptr = self.buffer.tell() self.buffer.seek(0, 2), self.buffer.write(b), self.buffer.seek(ptr) self.newlines = 0 return self.buffer.read() def readline(self): while self.newlines == 0: b = os.read(self._fd, max(os.fstat(self._fd).st_size, BUFSIZE)) self.newlines = b.count(b"\n") + (not b) ptr = self.buffer.tell() self.buffer.seek(0, 2), self.buffer.write(b), self.buffer.seek(ptr) self.newlines -= 1 return self.buffer.readline() def flush(self): if self.writable: os.write(self._fd, self.buffer.getvalue()) self.buffer.truncate(0), self.buffer.seek(0) class IOWrapper(IOBase): def __init__(self, file): self.buffer = FastIO(file) self.flush = self.buffer.flush self.writable = self.buffer.writable self.write = lambda s: self.buffer.write(s.encode("ascii")) self.read = lambda: self.buffer.read().decode("ascii") self.readline = lambda: self.buffer.readline().decode("ascii") sys.stdin = IOWrapper(sys.stdin) input = sys.stdin.readline ############################################################################### def example(): global input example = iter( """ 9 0 1 2 3 4 5 6 7 8 """ .strip().split("\n")) input = lambda: next(example) ############################################################################### import sys input = sys.stdin.readline from bisect import bisect_left, bisect_right # example() MOD=998244353 N=int(input()) A=list(map(int, input().split())) L = SortedList2(A) R = SortedList2(A,A) res=0 for i in range(N): a=A[i] R.remove(a) l=len(L)-L.bisect_right(a) r=R.bisect_left(a) res+=l*r*a res%=MOD L.add(a) ####### 更新:アフィン変換 取得:加算 (高速) ######################################################## shift1=N.bit_length() # 配列の長さの取り得る値(座圧している場合の上限に注意) shift2=32 # アフィン変換の加算部分の取り得る値 mask1=(1<>shift1, x&mask1 y0, y1 = y>>shift1, y&mask1 return (((x0+y0)%MOD)<>shift2, a&mask2 b0, b1 = b>>shift2, b&mask2 return ((a0*b0%MOD)<>shift2, a&mask2 x0, x1 = x>>shift1, x&mask1 return (((a0*x0%MOD+a1*x1%MOD)%MOD)< x*b+c) # (how to get) # res = st.get_range(l, r)>>shift1 ################################################################################ st = LazySegmentTree(N, f, g, merge, ef, eh) st2 = LazySegmentTree(N, f, g, merge, ef, eh) decode=sorted(A) code={x:i for i,x in enumerate(decode)} for i in range(N): a=A[i] st[code[a]]+= 1 large, size =divmod(st.get_range_right_half(code[a]+1), 1<