"""Sum, modulo 998244353, of a[i] + a[j] + a[k] over all index triples
i < j < k whose values are strictly decreasing (a[i] > a[j] > a[k]).

Input (stdin): first line N, second line N integers. Output: the answer.
"""
import sys

MOD = 998244353


class FenwickTree:
    """Binary indexed tree over ``n`` slots with 0-indexed point updates.

    Internally ``data`` is 1-indexed: public slot ``i`` lives at
    ``data[i + 1]``.
    """

    def __init__(self, n):
        self.n = n
        self.data = [0] * (n + 2)

    def sum(self, l, r):
        """Return the sum of slots in the half-open 0-indexed range [l, r)."""
        s = 0
        # prefix(r) - prefix(l), where prefix(k) covers slots 0..k-1.
        while l > 0:
            s -= self.data[l]
            l -= l & -l
        while r > 0:
            s += self.data[r]
            r -= r & -r
        return s

    def add(self, i, x):
        """Add ``x`` to slot ``i`` (0-indexed)."""
        i += 1
        while i <= self.n:
            self.data[i] += x
            i += i & -i

    def lowerbound(self, s):
        """Return the smallest ``r`` with ``self.sum(0, r) >= s``.

        Assumes every slot value is non-negative (monotone prefix sums) and
        ``s > 0``. Returns ``n + 1`` when even the total is below ``s``.
        """
        x = 0
        y = 0
        # Binary lifting: grow x while the prefix sum stays below s.
        for i in range(self.n.bit_length(), -1, -1):
            k = x + (1 << i)
            if k <= self.n and (y + self.data[k] < s):
                y += self.data[k]
                x += 1 << i
        return x + 1


def solve(a, mod=MOD):
    """Return sum(a[i] + a[j] + a[k]) % mod over i < j < k, a[i] > a[j] > a[k].

    Runs in O(n log n) using two passes over the values in descending order.
    """
    n = len(a)
    # Largest value first; among equal values the larger index comes first,
    # so when position i is visited no earlier position holding the same
    # value has been inserted yet -> prefix queries see only strictly larger
    # values, which makes the triples strictly decreasing.
    order = sorted(((x, i) for i, x in enumerate(a)), reverse=True)

    # Pass 1: pair_sum[i] / pair_cnt[i] = sum / count of a[j] over
    # j < i with a[j] > a[i].
    val_tree = FenwickTree(n)
    cnt_tree = FenwickTree(n)
    pair_sum = [0] * n
    pair_cnt = [0] * n
    for x, i in order:
        pair_sum[i] = val_tree.sum(0, i) % mod
        pair_cnt[i] = cnt_tree.sum(0, i)
        val_tree.add(i, x)
        cnt_tree.add(i, 1)

    # Pass 2: i is the last element of a triple. Slot j stores the total of
    # a[h] + a[j] over decreasing pairs h < j (reduced mod, so tree sums stay
    # small) and the number of such pairs; each pair extended by a[i]
    # contributes x once more.
    val_tree = FenwickTree(n)
    cnt_tree = FenwickTree(n)
    res = 0
    for x, i in order:
        res = (res + val_tree.sum(0, i) + cnt_tree.sum(0, i) * x) % mod
        val_tree.add(i, (pair_sum[i] + pair_cnt[i] * x) % mod)
        cnt_tree.add(i, pair_cnt[i])
    return res


def main():
    """Read N and the array from stdin and print the answer."""
    readline = sys.stdin.readline
    n = int(readline())
    # Only the first N values are used, as in the original script.
    a = list(map(int, readline().split()))[:n]
    print(solve(a))


if __name__ == "__main__":
    main()