class WaveletMatrix: def __init__(self, V): self.n = len(V) self.lg = max(V).bit_length() if V else 0 if self.lg == 0: self.lg = 1 self.B = [0] * self.lg self.zeros = [0] * self.lg self.accs = [] self.original_V = V curr_V = list(V) # ビット列の構築 for i in range(self.lg): bit = self.lg - 1 - i b = 0 zero_arr = [] one_arr = [] z_app = zero_arr.append o_app = one_arr.append mask = 1 << bit for j, v in enumerate(curr_V): if v & mask: b |= (1 << j) o_app(v) else: z_app(v) self.B[i] = b self.zeros[i] = len(zero_arr) curr_V = zero_arr + one_arr acc = [0] * (self.n + 1) for j in range(self.n): acc[j+1] = acc[j] + curr_V[j] self.accs.append(acc) acc_orig = [0] * (self.n + 1) for j in range(self.n): acc_orig[j+1] = acc_orig[j] + V[j] self.accs_orig = acc_orig def access(self, i): return self.original_V[i] def rank(self, r, x): if (x >> self.lg) & 1: return 0 B = self.B zeros = self.zeros lg = self.lg for i in range(lg): bit = (x >> (lg - 1 - i)) & 1 b_val = B[i] r1 = (b_val & ((1 << r) - 1)).bit_count() if bit: r = zeros[i] + r1 else: r = r - r1 return r def rank_range(self, l, r, x): return self.rank(r, x) - self.rank(l, x) def quantile(self, l, r, k): res = 0 B = self.B zeros = self.zeros lg = self.lg for i in range(lg): b_val = B[i] l1 = (b_val & ((1 << l) - 1)).bit_count() r1 = (b_val & ((1 << r) - 1)).bit_count() ones = r1 - l1 z = (r - l) - ones if k < z: l = l - l1 r = r - r1 else: res |= (1 << (lg - 1 - i)) k -= z z_cnt = zeros[i] l = z_cnt + l1 r = z_cnt + r1 return res def _range_freq(self, l, r, x): if x.bit_length() > self.lg: return r - l res = 0 B = self.B zeros = self.zeros lg = self.lg for i in range(lg): if l == r: break bit = (x >> (lg - 1 - i)) & 1 b_val = B[i] # 多倍長整数を用いた O(1) での rank1 計算 l1 = (b_val & ((1 << l) - 1)).bit_count() r1 = (b_val & ((1 << r) - 1)).bit_count() l0 = l - l1 r0 = r - r1 if bit: res += r0 - l0 z = zeros[i] l = z + l1 r = z + r1 else: l = l0 r = r0 return res def range_freq(self, left, right, lower, upper): return self._range_freq(left, right, upper) - self._range_freq(left, right, lower) def prev_value(self, left, right, upper): cnt = self._range_freq(left, right, upper) return self.quantile(left, right, cnt - 1) if cnt > 0 else None def next_value(self, left, right, lower): cnt = self._range_freq(left, right, lower) return self.quantile(left, right, cnt) if cnt < right - left else None def _range_sum(self, l, r, x): if self.lg < x.bit_length(): return self.accs_orig[r] - self.accs_orig[l] res = 0 B = self.B zeros = self.zeros accs = self.accs lg = self.lg for i in range(lg): if l == r: break bit = (x >> (lg - 1 - i)) & 1 b_val = B[i] l1 = (b_val & ((1 << l) - 1)).bit_count() r1 = (b_val & ((1 << r) - 1)).bit_count() l0 = l - l1 r0 = r - r1 if bit: res += accs[i][r0] - accs[i][l0] z = zeros[i] l = z + l1 r = z + r1 else: l = l0 r = r0 return res def range_sum(self, left, right, lower, upper): return self._range_sum(left, right, upper) - self._range_sum(left, right, lower) def _build_distinct_wm(self): P = [0] * self.n last_pos = {} for i, v in enumerate(self.original_V): P[i] = last_pos.get(v, -1) + 1 last_pos[v] = i self._distinct_wm = WaveletMatrix(P) def range_distinct(self, left, right): """ 区間 [left, right) に含まれる要素の種類数を返す """ if not hasattr(self, "_distinct_wm"): self._build_distinct_wm() return self._distinct_wm.range_freq(left, right, 0, left + 1) def bottom_k_sum(self, l, r, k): """ 区間 [l, r) の中で小さい方から k 個の要素の和を返す """ if k <= 0: return 0 if k >= r - l: return self.accs_orig[r] - self.accs_orig[l] res = 0 val = 0 B = self.B zeros = self.zeros accs = self.accs lg = self.lg for i in range(lg): b_val = B[i] l1 = (b_val & ((1 << l) - 1)).bit_count() r1 = (b_val & ((1 << r) - 1)).bit_count() ones = r1 - l1 z = (r - l) - ones l0 = l - l1 r0 = r - r1 if k <= z: l = l0 r = r0 else: res += accs[i][r0] - accs[i][l0] k -= z val |= (1 << (lg - 1 - i)) z_cnt = zeros[i] l = z_cnt + l1 r = z_cnt + r1 res += k * val return res def top_k_sum(self, l, r, k): """ 区間 [l, r) の中で大きい方から k 個の要素の和を返す """ if k <= 0: return 0 length = r - l if k >= length: return self.accs_orig[r] - self.accs_orig[l] total_sum = self.accs_orig[r] - self.accs_orig[l] return total_sum - self.bottom_k_sum(l, r, length - k) MOD = 998244353 N = int(input()) A = list(map(int,input().split())) MAX = max(A) WM = WaveletMatrix(A) ans = 0 for i in range(1,N-1): ans += WM.range_sum(0,i,A[i]+1,MAX+1) * WM.range_freq(i+1,N,0,A[i]) ans += A[i] * WM.range_freq(0,i,A[i]+1,MAX+1) * WM.range_freq(i+1,N,0,A[i]) ans += WM.range_freq(0,i,A[i]+1,MAX+1) * WM.range_sum(i+1,N,0,A[i]) ans %= MOD print(ans)