結果

問題 No.1300 Sum of Inversions
ユーザー mkmkmkmkmkmkmkmk
提出日時 2020-12-21 18:04:59
言語 C++14
(gcc 12.3.0 + boost 1.83.0)
結果
WA  
実行時間 -
コード長 4,373 bytes
コンパイル時間 1,987 ms
コンパイル使用メモリ 177,548 KB
実行使用メモリ 18,872 KB
最終ジャッジ日時 2023-10-21 11:45:17
合計ジャッジ時間 8,507 ms
ジャッジサーバーID
(参考情報)
judge11 / judge10
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 2 ms
4,348 KB
testcase_01 AC 2 ms
4,348 KB
testcase_02 AC 2 ms
4,348 KB
testcase_03 WA -
testcase_04 WA -
testcase_05 AC 105 ms
13,072 KB
testcase_06 AC 150 ms
16,768 KB
testcase_07 WA -
testcase_08 AC 167 ms
17,560 KB
testcase_09 AC 161 ms
17,560 KB
testcase_10 WA -
testcase_11 WA -
testcase_12 WA -
testcase_13 AC 130 ms
14,920 KB
testcase_14 AC 176 ms
18,812 KB
testcase_15 AC 163 ms
17,296 KB
testcase_16 WA -
testcase_17 WA -
testcase_18 WA -
testcase_19 WA -
testcase_20 WA -
testcase_21 WA -
testcase_22 WA -
testcase_23 AC 152 ms
16,768 KB
testcase_24 WA -
testcase_25 WA -
testcase_26 WA -
testcase_27 WA -
testcase_28 WA -
testcase_29 AC 116 ms
13,864 KB
testcase_30 AC 160 ms
17,296 KB
testcase_31 AC 108 ms
13,072 KB
testcase_32 WA -
testcase_33 AC 43 ms
12,740 KB
testcase_34 AC 99 ms
12,740 KB
testcase_35 AC 114 ms
18,812 KB
testcase_36 WA -
権限があれば一括ダウンロードができます

ソースコード

diff #

#include <bits/stdc++.h>
using namespace std;
typedef long long lint;
#define rep(i,n) for(lint (i)=0;(i)<(n);(i)++)
#define repp(i,m,n) for(lint (i)=(m);(i)<(n);(i)++)
#define repm(i,n) for(lint (i)=(n-1);(i)>=0;(i)--)
#define INF (1ll<<60)
#define all(x) (x).begin(),(x).end()
//const lint MOD =1000000007;
const lint MOD=998244353;
const lint MAX = 4000000;
using Graph =vector<vector<lint>>;
typedef pair<lint,lint> P;
typedef map<lint,lint> M;
#define chmax(x,y) x=max(x,y)
#define chmin(x,y) x=min(x,y)

 
lint fac[MAX], finv[MAX], inv[MAX];
 
void COMinit() 
{
    fac[0] = fac[1] = 1;
    finv[0] = finv[1] = 1;
    inv[1] = 1;
    for (lint i = 2; i < MAX; i++)
    {
        fac[i] = fac[i - 1] * i % MOD;
        inv[i] = MOD - inv[MOD % i] * (MOD / i) % MOD;
        finv[i] = finv[i - 1] * inv[i] % MOD;
    }
}
 
long long COM(lint n, lint k)
{
    if (n < k)
        return 0;
    if (n < 0 || k < 0)
        return 0;
    return fac[n] * (finv[k] * finv[n - k] % MOD) % MOD;
}
 
lint primary(lint num)
{
    if (num < 2) return 0;
    else if (num == 2) return 1;
    else if (num % 2 == 0) return 0;
 
    double sqrtNum = sqrt(num);
    for (int i = 3; i <= sqrtNum; i += 2)
    {
        if (num % i == 0)
        {
            return 0;
        }
    }
 
    return 1;
}
   long long modpow(long long a, long long n, long long mod) {
    long long res = 1;
    while (n > 0) {
        if (n & 1) res = res * a % mod;
        a = a * a % mod;
        n >>= 1;
    }
    return res;
}
    lint lcm(lint a,lint b){
        return a/__gcd(a,b)*b;
    }
     lint gcd(lint a,lint b){
        return __gcd(a,b);
    } 
    class BIT {
public:
    //データの長さ
    lint n;
    //データの格納先
    vector<lint> a;
    //コンストラクタ
    BIT(lint n):n(n),a(n+1,0){}

    //a[i]にxを加算する
    void add(lint i,lint x){
        i++;
        if(i==0) return;
        for(lint k=i;k<=n;k+=(k & -k)){
            a[k]+=x;
        }
    }

    //a[i]+a[i+1]+…+a[j]を求める
    lint sum(lint i,lint j){
        return sum_sub(j)-sum_sub(i-1);
    }

    //a[0]+a[1]+…+a[i]を求める
    lint sum_sub(lint i){
        i++;
        lint s=0;
        if(i==0) return s;
        for(lint k=i;k>0;k-=(k & -k)){
            s+=a[k];
        }
        return s;
    }

    //a[0]+a[1]+…+a[i]>=xとなる最小のiを求める(任意のkでa[k]>=0が必要)
    lint lower_bound(lint x){
        if(x<=0){
            //xが0以下の場合は該当するものなし→0を返す
            return 0;
        }else{
            lint i=0;lint r=1;
            //最大としてありうる区間の長さを取得する
            //n以下の最小の二乗のべき(BITで管理する数列の区間で最大のもの)を求める
            while(r<n) r=r<<1;
            //区間の長さは調べるごとに半分になる
            for(int len=r;len>0;len=len>>1) {
                //その区間を採用する場合
                if(i+len<n && a[i+len]<x){
                    x-=a[i+len];
                    i+=len;
                }
            }
            return i;
        }
    }
};
    template <class T>
vector<T> press(vector<T> &x) {
	auto res = x;
	sort(res.begin(), res.end());
	res.erase(unique(res.begin(), res.end()), res.end());
	for(int i = 0; i < (int)x.size(); i++)
		x[i] = lower_bound(res.begin(), res.end(), x[i]) - res.begin();
	return res;
}

    int main(){
      lint n;
      cin>>n;
      vector<lint> a(n);
      rep(i,n)cin>>a[i];
      auto x=press(a);
      lint sz=x.size();
      BIT sum(sz+5);
      BIT num(sz+5);
      lint xx[n],l[n],y[n],r[n];
      rep(i,n){
          sum.add(a[i],x[a[i]]);
          num.add(a[i],1);
          xx[i]=sum.sum(a[i]+1,sz);
          xx[i]%=MOD;
          l[i]=num.sum(a[i]+1,sz);
      }
      {
      BIT sum(sz+5);
      BIT num(sz+5);
      repm(i,n){
          sum.add(a[i],x[a[i]]);
          num.add(a[i],1);
          y[i]=sum.sum(0,a[i]-1);
          y[i]%=MOD;
          r[i]=num.sum(0,a[i]-1);
      }
      }
      lint ans=0;
      rep(i,n){
          if(l[i]*r[i]==0)continue;
          ans+=xx[i]*r[i];
          ans%=MOD;
          ans+=y[i]*l[i];
          ans%=MOD;
          lint add=l[i]*r[i];
          add%=MOD;
          add*=x[a[i]];
          add%=MOD;
          ans+=add;
      }
        cout<<ans<<endl;
    }
      
0