#include using namespace std; typedef long long lint; #define rep(i,n) for(lint (i)=0;(i)<(n);(i)++) #define repp(i,m,n) for(lint (i)=(m);(i)<(n);(i)++) #define repm(i,n) for(lint (i)=(n-1);(i)>=0;(i)--) #define INF (1ll<<60) #define all(x) (x).begin(),(x).end() //const lint MOD =1000000007; const lint MOD=998244353; const lint MAX = 4000000; using Graph =vector>; typedef pair P; typedef map M; #define chmax(x,y) x=max(x,y) #define chmin(x,y) x=min(x,y) lint fac[MAX], finv[MAX], inv[MAX]; void COMinit() { fac[0] = fac[1] = 1; finv[0] = finv[1] = 1; inv[1] = 1; for (lint i = 2; i < MAX; i++) { fac[i] = fac[i - 1] * i % MOD; inv[i] = MOD - inv[MOD % i] * (MOD / i) % MOD; finv[i] = finv[i - 1] * inv[i] % MOD; } } long long COM(lint n, lint k) { if (n < k) return 0; if (n < 0 || k < 0) return 0; return fac[n] * (finv[k] * finv[n - k] % MOD) % MOD; } lint primary(lint num) { if (num < 2) return 0; else if (num == 2) return 1; else if (num % 2 == 0) return 0; double sqrtNum = sqrt(num); for (int i = 3; i <= sqrtNum; i += 2) { if (num % i == 0) { return 0; } } return 1; } long long modpow(long long a, long long n, long long mod) { long long res = 1; while (n > 0) { if (n & 1) res = res * a % mod; a = a * a % mod; n >>= 1; } return res; } lint lcm(lint a,lint b){ return a/__gcd(a,b)*b; } lint gcd(lint a,lint b){ return __gcd(a,b); } class BIT { public: //データの長さ lint n; //データの格納先 vector a; //コンストラクタ BIT(lint n):n(n),a(n+1,0){} //a[i]にxを加算する void add(lint i,lint x){ i++; if(i==0) return; for(lint k=i;k<=n;k+=(k & -k)){ a[k]+=x; } } //a[i]+a[i+1]+…+a[j]を求める lint sum(lint i,lint j){ return sum_sub(j)-sum_sub(i-1); } //a[0]+a[1]+…+a[i]を求める lint sum_sub(lint i){ i++; lint s=0; if(i==0) return s; for(lint k=i;k>0;k-=(k & -k)){ s+=a[k]; } return s; } //a[0]+a[1]+…+a[i]>=xとなる最小のiを求める(任意のkでa[k]>=0が必要) lint lower_bound(lint x){ if(x<=0){ //xが0以下の場合は該当するものなし→0を返す return 0; }else{ lint i=0;lint r=1; //最大としてありうる区間の長さを取得する //n以下の最小の二乗のべき(BITで管理する数列の区間で最大のもの)を求める while(r0;len=len>>1) { //その区間を採用する場合 if(i+len vector press(vector &x) { auto res = x; sort(res.begin(), res.end()); res.erase(unique(res.begin(), res.end()), res.end()); for(int i = 0; i < (int)x.size(); i++) x[i] = lower_bound(res.begin(), res.end(), x[i]) - res.begin(); return res; } int main(){ lint n; cin>>n; vector a(n); rep(i,n)cin>>a[i]; auto x=press(a); lint sz=x.size(); BIT sum(sz+5); BIT num(sz+5); lint xx[n],l[n],y[n],r[n]; rep(i,n){ sum.add(a[i],x[a[i]]); num.add(a[i],1); xx[i]=sum.sum(a[i]+1,sz); xx[i]%=MOD; l[i]=num.sum(a[i]+1,sz); } { BIT sum(sz+5); BIT num(sz+5); repm(i,n){ sum.add(a[i],x[a[i]]); num.add(a[i],1); y[i]=sum.sum(0,a[i]-1); y[i]%=MOD; r[i]=num.sum(0,a[i]-1); } } lint ans=0; rep(i,n){ if(l[i]*r[i]==0)continue; ans+=xx[i]*r[i]; ans%=MOD; ans+=y[i]*l[i]; ans%=MOD; lint add=l[i]*r[i]; add%=MOD; add*=x[a[i]]; add%=MOD; ans+=add; } cout<