"""Wavelet matrix over non-negative integers plus a triple-sum driver.

The driver reads ``N`` and a sequence ``A`` from standard input and prints,
modulo ``MOD``, the sum of ``A[i] + A[j] + A[k]`` over all index triples
``i < j < k`` with ``A[i] > A[j] > A[k]``.
"""


class WaveletMatrix:
    """Static structure answering order-statistic queries on index ranges.

    All ranges are half-open: ``[l, r)`` for indices and ``[lower, upper)``
    for values.  Values are expected to be non-negative integers.
    """

    def __init__(self, V):
        """Build the per-bit rank tables and prefix sums for the sequence V."""
        self.n = len(V)
        # default=0 keeps construction from raising on an empty sequence.
        self.lg = max(V, default=0).bit_length()
        self.ranks = []
        self.accs = []
        self.original_V = V
        V = list(V)
        n = self.n
        # Process bits from most significant to least significant.
        for bit in range(self.lg - 1, -1, -1):
            # rank[i] = number of elements among the first i whose bit is set.
            rank = [0] * (n + 1)
            for i, v in enumerate(V):
                rank[i + 1] = rank[i] + ((v >> bit) & 1)
            # Stable partition: elements with a 0 bit first, then 1 bit.
            swp = [0] * n
            zero, one = 0, n - rank[n]
            for v in V:
                if (v >> bit) & 1:
                    swp[one] = v
                    one += 1
                else:
                    swp[zero] = v
                    zero += 1
            # acc[i] = sum of the first i elements of the partitioned order.
            acc = [0] * (n + 1)
            for i, v in enumerate(swp):
                acc[i + 1] = acc[i] + v
            V = swp
            self.ranks.append(rank)
            self.accs.append(acc)
        # The last table holds prefix sums of the sequence in original order.
        self.accs.append([0])
        for v in self.original_V:
            self.accs[-1].append(self.accs[-1][-1] + v)

    def access(self, i):
        """Return the element at index i of the original sequence."""
        return self.original_V[i]

    def rank(self, r, x):
        """Return how many elements in [0, r) are strictly less than x."""
        return self._range_freq(0, r, x)

    def rank_range(self, l, r, x):
        """Return how many elements in [l, r) are strictly less than x."""
        return self.rank(r, x) - self.rank(l, x)

    def quantile(self, l, r, k):
        """Return the k-th smallest (0-indexed) element in [l, r).

        The caller must ensure 0 <= k < r - l.
        """
        res = 0
        for i in range(self.lg - 1, -1, -1):
            rank = self.ranks[self.lg - 1 - i]
            ones = rank[r] - rank[l]
            zeros = (r - l) - ones
            if k < zeros:
                # The answer has a 0 at this bit: descend into the zero part.
                l -= rank[l]
                r -= rank[r]
            else:
                # Skip all zero-bit elements and descend into the one part.
                k -= zeros
                res |= 1 << i
                zero_sum = self.n - rank[self.n]
                l = zero_sum + rank[l]
                r = zero_sum + rank[r]
        return res

    def _range_freq(self, l, r, x):
        """Return how many elements in [l, r) are strictly less than x."""
        # x has more bits than any stored value, so every element is smaller.
        if x.bit_length() > self.lg:
            return r - l
        res = 0
        for i in range(self.lg - 1, -1, -1):
            bit = (x >> i) & 1
            rank = self.ranks[self.lg - 1 - i]
            ones = rank[r] - rank[l]
            zeros = (r - l) - ones
            if bit:
                # Elements with a 0 here (same higher bits) are all below x.
                res += zeros
                zero_sum = self.n - rank[self.n]
                l = zero_sum + rank[l]
                r = zero_sum + rank[r]
            else:
                l -= rank[l]
                r -= rank[r]
        return res

    def range_freq(self, left, right, lower, upper):
        """Count elements v in [left, right) with lower <= v < upper."""
        return (self._range_freq(left, right, upper)
                - self._range_freq(left, right, lower))

    def prev_value(self, left, right, upper):
        """Return the largest element below upper in [left, right), or None."""
        cnt = self._range_freq(left, right, upper)
        return self.quantile(left, right, cnt - 1) if cnt > 0 else None

    def next_value(self, left, right, lower):
        """Return the smallest element >= lower in [left, right), or None."""
        cnt = self._range_freq(left, right, lower)
        return self.quantile(left, right, cnt) if cnt < right - left else None

    def _range_sum(self, l, r, x):
        """Return the sum of elements in [l, r) strictly less than x."""
        # x exceeds every stored value: the answer is the plain range sum.
        if x.bit_length() > self.lg:
            return self.accs[-1][r] - self.accs[-1][l]
        res = 0
        for i in range(self.lg - 1, -1, -1):
            bit = (x >> i) & 1
            rank = self.ranks[self.lg - 1 - i]
            acc = self.accs[self.lg - 1 - i]
            if bit:
                zero_sum = self.n - rank[self.n]
                # Positions of this range's zero-bit elements after the
                # partition; all of them are below x, so add their sum.
                l0 = l - rank[l]
                r0 = r - rank[r]
                res += acc[r0] - acc[l0]
                l = zero_sum + rank[l]
                r = zero_sum + rank[r]
            else:
                l -= rank[l]
                r -= rank[r]
        return res

    def range_sum(self, left, right, lower, upper):
        """Sum elements v in [left, right) with lower <= v < upper."""
        return (self._range_sum(left, right, upper)
                - self._range_sum(left, right, lower))


MOD = 998244353


def solve(A, N=None):
    """Return the triple sum for A modulo MOD.

    For every middle index i, counts and sums the strictly larger elements to
    its left and the strictly smaller elements to its right, then combines
    them so each decreasing triple contributes A[i] + A[j] + A[k] once.

    N is the number of leading elements treated as the sequence length; it
    defaults to len(A).
    """
    if N is None:
        N = len(A)
    ans = 0
    wm = WaveletMatrix(A)
    for i in range(1, N - 1):
        left = wm.range_freq(0, i, A[i] + 1, 10**18)
        left_sum = wm.range_sum(0, i, A[i] + 1, 10**18)
        right = wm.range_freq(i + 1, N, 0, A[i])
        right_sum = wm.range_sum(i + 1, N, 0, A[i])
        ans += left * right * A[i] % MOD
        ans += left_sum * right % MOD
        ans += left * right_sum % MOD
        ans %= MOD
    return ans


def main():
    """Read N and the sequence from standard input and print the answer."""
    N = int(input())
    A = list(map(int, input().split()))
    print(solve(A, N))


if __name__ == "__main__":
    main()