import bisect
MOD = 998244353  # modulus for all arithmetic below (prime, so Fermat inverses work)


def main() -> None:
    """Read n and the array A from stdin and print one residue modulo MOD.

    Input: all whitespace-separated integers on stdin. The first is n, the
    next n form the array A, and any further integers are ignored.

    The printed value is
        T * (cnt_all_pairs * S_mod * inv_n_square
             + inv_count * sum_k_part * inv_n_times_n_minus_1)   (mod MOD)
    where T = prod_{i=1..n} C(n, i), cnt_all_pairs is the number of ordered
    index pairs (i, j) with A[i] < A[j], and inv_count is the number of
    pairs i < j with A[i] > A[j] (the inversion count of A).

    NOTE(review): part of this function was lost from SOURCE (see the note
    at "Compute S" below). S_mod, inv_n_square, sum_k_part and n_mod are
    used but never defined in the visible code, so the function cannot run
    as shown. Restore the missing text before relying on it.
    """
    import sys
    n, *rest = list(map(int, sys.stdin.read().split()))
    A = rest[:n]
    # Compute cnt_all_pairs and cnt_inversion_pairs using Fenwick Tree
    # Coordinate compression for all elements
    sorted_unique = sorted(list(set(A)))
    # NOTE(review): `rank` is never read; ranks are found with bisect instead.
    rank = {v:i for i, v in enumerate(sorted_unique)}
    m = len(sorted_unique)
    # freq[r] = number of occurrences in A of the r-th smallest distinct value
    freq = [0] * m
    for v in A:
        idx = bisect.bisect_left(sorted_unique, v)
        freq[idx] += 1

    # Compute cnt_all_pairs
    class FenwickTree:
        """Binary indexed (Fenwick) tree of integer sums over positions 1..size."""

        def __init__(self, size: int):
            self.n = size
            self.tree = [0] * (self.n + 1)  # index 0 unused (1-based tree)

        def update(self, idx: int, delta: int) -> None:
            """Add delta to position idx (1-based)."""
            while idx <= self.n:
                self.tree[idx] += delta
                idx += idx & -idx

        def query(self, idx: int) -> int:
            """Return the sum of positions 1..idx (0 when idx <= 0)."""
            res = 0
            while idx > 0:
                res += self.tree[idx]
                idx -= idx & -idx
            return res

        def range_query(self, l: int, r: int) -> int:
            """Return the sum of positions l..r inclusive (1-based)."""
            return self.query(r) - self.query(l - 1)

    ft_all = FenwickTree(m)
    for i in range(m):
        ft_all.update(i + 1, freq[i]) # 1-based indexing
    # NOTE(review): ft_all is never modified after this loop, so a plain
    # suffix-sum array over freq would give the same counts.
    cnt_all_pairs = 0
    for v in A:
        # bisect_right gives the 0-based rank of the smallest distinct value
        # strictly greater than v. The 1-based ranks idx+1..m therefore hold
        # exactly the values > v.
        idx = bisect.bisect_right(sorted_unique, v)
        if idx < m:
            count = ft_all.range_query(idx + 1, m) # 1-based to m
            cnt_all_pairs += count
    cnt_all_pairs %= MOD

    # Compute cnt_inversion_pairs (normal inversion count), stored in inv_count
    # Coordinate compression for the original array for inversion count
    # NOTE(review): sorted_A is identical to sorted_unique, and rank_inv is
    # never read; both duplicate earlier work.
    sorted_A = sorted(list(set(A)))
    rank_inv = {v:i+1 for i, v in enumerate(sorted_A)} # 1-based
    max_rank = len(sorted_A)
    ft_inv = FenwickTree(max_rank)
    inv_count = 0
    # Scan right to left. Before A[j] is inserted, ft_inv holds the values
    # A[j+1..n-1], so query(v_rank - 1) counts later elements that are
    # strictly smaller than A[j].
    for j in reversed(range(n)):
        v_rank = bisect.bisect_left(sorted_A, A[j])
        v_rank += 1 # 1-based
        inv_count += ft_inv.query(v_rank - 1)
        ft_inv.update(v_rank, 1)
    inv_count %= MOD

    # Precompute factorials and inverse factorials:
    # fact[i] = i! mod MOD and inv_fact[i] = (i!)^-1 mod MOD.
    max_fact = n
    fact, inv_fact = [1]*(max_fact + 1), [1]*(max_fact + 1)
    for i in range(1, max_fact + 1):
        fact[i] = fact[i-1] * i % MOD
    # Fermat's little theorem: x^(MOD-2) is x^-1 modulo the prime MOD.
    inv_fact[max_fact] = pow(fact[max_fact], MOD-2, MOD)
    # Walk downward using 1/i! = (i+1) * 1/(i+1)!
    for i in range(max_fact-1, -1, -1):
        inv_fact[i] = inv_fact[i+1] * (i+1) % MOD

    def comb(n_: int, k_: int) -> int:
        """Return C(n_, k_) mod MOD, or 0 when k_ is outside [0, n_]."""
        if k_ < 0 or k_ > n_:
            return 0
        return fact[n_] * inv_fact[k_] % MOD * inv_fact[n_ - k_] % MOD

    # Compute T = product of C(n,i) for i in 1..n
    T = 1
    for i in range(1, n+1):
        c = comb(n, i)
        T = T * c % MOD

    # Compute S = sum_{k= 2:
    # NOTE(review): SOURCE is damaged here. The text between "sum_{k" and
    # "= 2:" has been lost. It evidently contained:
    #   - the rest of this comment;
    #   - the code defining S_mod, inv_n_square, sum_k_part and n_mod (all
    #     used below but never assigned in the visible code);
    #   - the header of the `if ...:` that the `else:` below belongs to.
    # Given the "n<2" remark on the else branch, that header was presumably
    # `if n >= 2:`, and n_mod was presumably n % MOD — TODO confirm against
    # version control. As shown, this region is not valid Python.
        # Modular inverse of n_mod * (n - 1), via Fermat's little theorem
        inv_n_times_n_minus_1 = pow(n_mod * ( (n-1) % MOD ) % MOD, MOD-2, MOD)
    else:
        inv_n_times_n_minus_1 = 0 # but for n<2, sum_k_part is 0, so term2 is 0

    # term1 = cnt_all_pairs * S_mod * inv_n_square. inv_n_square is presumably
    # the inverse of n^2 mod MOD; its definition is lost (TODO confirm).
    term1 = cnt_all_pairs * S_mod % MOD
    term1 = term1 * inv_n_square % MOD
    # term2 = inv_count * sum_k_part / (n * (n - 1)). It is 0 on the else
    # branch above, which the original comment says covers n < 2.
    term2 = inv_count * sum_k_part % MOD
    term2 = term2 * inv_n_times_n_minus_1 % MOD
    # Scale the combined term by T and print the residue.
    ans = (term1 + term2) % MOD
    ans = ans * T % MOD
    print(ans)


if __name__ == '__main__':
    main()