from collections import defaultdict
import random


class BIT():
    """Fenwick (binary indexed) tree over 1-based positions 1..n.

    Supports point addition and prefix / range sums in O(log n) each.
    """

    def __init__(self, n):
        # Index 0 is unused; valid positions are 1..n.
        self.n = n
        self.data = [0] * (n + 1)

    def to_sum(self, i):
        """Return the sum of positions 1..i (0 when i <= 0)."""
        s = 0
        while i > 0:
            s += self.data[i]
            i -= i & -i
        return s

    def add(self, i, x):
        """Add x at position i.

        Raises:
            ValueError: if i < 1. Position 0 would never advance
                (0 & -0 == 0) and the loop would spin forever.
        """
        if i < 1:
            raise ValueError(f"BIT position must be >= 1, got {i}")
        while i <= self.n:
            self.data[i] += x
            i += i & -i

    def get(self, i, j):
        """Return the sum of positions i..j inclusive (0 for an empty range)."""
        if i > j:
            return 0
        return self.to_sum(j) - self.to_sum(i - 1)


def press(List):
    """Coordinate-compress List: map each distinct value to its 1-based rank.

    The result is a defaultdict(int), so a value that was not in List
    looks up as 0.
    """
    d = defaultdict(int)
    for rank, v in enumerate(sorted(set(List)), start=1):
        d[v] = rank
    return d


def solve(a, mod=998244353):
    """Return sum(a[j] + a[i] + a[k]) over all j < i < k with a[j] > a[i] > a[k], mod `mod`.

    For each middle index i, two Fenwick-tree passes give:
      * sum/count of earlier elements strictly greater than a[i], and
      * sum/count of later elements strictly smaller than a[i].
    Every earlier element pairs with every later one, which gives the
    per-middle contribution used below. The input list is not modified.
    """
    rank = press(a)
    m = len(rank)  # ranks are 1..m

    # Pass 1 (left to right): aggregates of earlier, strictly greater elements.
    greater_sum, greater_cnt = BIT(m), BIT(m)
    before = []
    for v in a:
        r = rank[v]
        before.append((greater_sum.get(r + 1, m), greater_cnt.get(r + 1, m)))
        greater_sum.add(r, v)
        greater_cnt.add(r, 1)

    # Pass 2 (right to left): aggregates of later, strictly smaller elements.
    smaller_sum, smaller_cnt = BIT(m), BIT(m)
    after = []
    for v in reversed(a):
        r = rank[v]
        after.append((smaller_sum.get(1, r - 1), smaller_cnt.get(1, r - 1)))
        smaller_sum.add(r, v)
        smaller_cnt.add(r, 1)
    after.reverse()

    ans = 0
    for v, (lsum, lcnt), (rsum, rcnt) in zip(a, before, after):
        # Each earlier element appears in rcnt triples, each later one in lcnt,
        # and the middle value appears in lcnt * rcnt triples.
        ans = (ans + lsum * rcnt + rsum * lcnt + v * lcnt * rcnt) % mod
    return ans


def main():
    """Read n and the array from stdin and print the answer."""
    input()  # first line is the element count n; len(a) is used instead
    a = list(map(int, input().split()))
    print(solve(a))


if __name__ == "__main__":
    main()