class SegmentTree:
    """Point-update / range-sum segment tree stored as a 0-indexed implicit heap.

    Node ``k`` has children ``2k+1`` (left, odd index) and ``2k+2`` (right,
    even index). Leaves occupy indices ``N-1 .. 2N-2``, where ``N`` is
    ``2 ** (len(a) - 1).bit_length()`` (the smallest power of two >= len(a);
    1 for a single element, 2 for an empty input). Unused leaves hold
    ``padding`` (0, the identity of ``+``).
    """

    def __init__(self, a):
        """Build the tree from the sequence ``a`` in O(N)."""
        self.padding = 0
        self.n = len(a)
        self.N = 2 ** (self.n - 1).bit_length()
        self.seg_data = (
            [self.padding] * (self.N - 1)
            + list(a)
            + [self.padding] * (self.N - self.n)
        )
        # Fill internal nodes bottom-up; (i-1, i) for even i is a sibling pair.
        for i in range(2 * self.N - 2, 0, -2):
            self.seg_data[(i - 1) // 2] = self.seg_data[i] + self.seg_data[i - 1]

    def __len__(self):
        """Return the number of real (non-padding) elements."""
        return self.n

    def __getitem__(self, i):
        """Return element ``i`` (non-negative index only)."""
        return self.seg_data[self.N - 1 + i]

    def __setitem__(self, i, x):
        """Set element ``i`` to ``x`` and refresh every ancestor sum."""
        idx = self.N - 1 + i
        self.seg_data[idx] = x
        while idx:
            idx = (idx - 1) // 2
            self.seg_data[idx] = self.seg_data[2 * idx + 1] + self.seg_data[2 * idx + 2]

    def query(self, i, j):
        """Return the sum over the half-open range ``[i, j)``.

        Returns ``self.padding`` (0) when ``i >= j``: the loop below never
        runs in that case, so no special-casing is needed.
        """
        lo = self.N - 1 + i
        hi = self.N - 2 + j  # make the range closed: [lo, hi]
        result = self.padding
        while lo <= hi:
            if lo & 1 == 0:  # lo is a right child: take it; lo//2 is parent of lo+1
                result += self.seg_data[lo]
            if hi & 1 == 1:  # hi is a left child: take it and step left
                result += self.seg_data[hi]
                hi -= 1
            lo //= 2
            hi = (hi - 1) // 2
        return result

    def kth_left_idx(self, fr, k):
        """Return the largest ``x <= fr`` with ``sum(a[x..fr]) >= k``.

        Returns -1 if ``sum(a[0..fr]) < k``. The search relies on suffix
        sums growing monotonically as ``x`` decreases, i.e. it assumes all
        stored values are non-negative.
        """
        if self.query(0, fr + 1) < k:
            return -1
        remain = k
        now = fr + self.N - 1
        # Climb leftward, absorbing whole nodes whose sum is still too small.
        while self.seg_data[now] < remain:
            if now % 2:  # left child: consume it, move to the node just left of it
                remain -= self.seg_data[now]
                now -= 1
            else:  # right child: its parent ends at the same position
                now = (now - 1) // 2
        # Descend, preferring the right child to keep x as large as possible.
        while now < self.N - 1:
            nl = 2 * now + 1
            nr = nl + 1
            if self.seg_data[nr] < remain:
                remain -= self.seg_data[nr]
                now = nl
            else:
                now = nr
        return now - (self.N - 1)

    def kth_right_idx(self, fr, k):
        """Return the smallest ``x >= fr`` with ``sum(a[fr..x]) >= k``.

        Returns -1 if ``sum(a[fr..n-1]) < k``. Assumes all stored values
        are non-negative, as ``kth_left_idx`` does.
        """
        if self.query(fr, self.n) < k:
            return -1
        remain = k
        now = fr + self.N - 1
        # Climb rightward, absorbing whole nodes whose sum is still too small.
        while self.seg_data[now] < remain:
            if now % 2 == 0:  # right child: consume it, move to the node just right
                remain -= self.seg_data[now]
                now += 1
            else:  # left child: its parent starts at the same position
                now //= 2
        # Descend, preferring the left child to keep x as small as possible.
        while now < self.N - 1:
            nl = 2 * now + 1
            nr = nl + 1
            if self.seg_data[nl] < remain:
                remain -= self.seg_data[nl]
                now = nr
            else:
                now = nl
        return now - (self.N - 1)


MOD = 998244353


def solve(a, mod=MOD):
    """Return the sum of ``a[p] + a[i] + a[q]`` over all ``p < i < q`` with
    ``a[p] > a[i] > a[q]`` (strict), reduced modulo ``mod``.

    Each element is treated as the middle of a triple. Let L be the number
    of earlier strictly greater values, SL their sum, R the number of later
    strictly smaller values, and SR their sum. The element then contributes
    ``L*R*a[i] + R*SL + L*SR``. Runs in O(n log n) after coordinate
    compression.
    """
    values = sorted(set(a))
    m = len(values)
    rank = {v: r for r, v in enumerate(values)}
    length = len(a)

    # Pass 1 (left to right): count and sum of earlier strictly greater values.
    count_tree = SegmentTree([0] * m)
    sum_tree = SegmentTree([0] * m)
    left_cnt = [0] * length
    left_sum = [0] * length
    for i, x in enumerate(a):
        r = rank[x]
        left_cnt[i] = count_tree.query(r + 1, m)
        left_sum[i] = sum_tree.query(r + 1, m)
        count_tree[r] += 1
        sum_tree[r] += x

    # Pass 2 (right to left): count and sum of later strictly smaller values.
    count_tree = SegmentTree([0] * m)
    sum_tree = SegmentTree([0] * m)
    total = 0
    for i in range(length - 1, -1, -1):
        x = a[i]
        r = rank[x]
        right_cnt = count_tree.query(0, r)
        right_sum = sum_tree.query(0, r)
        total += right_cnt * left_cnt[i] * x % mod
        total += right_cnt * left_sum[i] % mod
        total += left_cnt[i] * right_sum % mod
        total %= mod
        count_tree[r] += 1
        sum_tree[r] += x
    return total


def main():
    """Read N, then the array on one line, from stdin and print the answer."""
    input()  # N; the array length is taken from the next line
    a = list(map(int, input().split()))
    print(solve(a))


if __name__ == "__main__":
    main()