結果

問題 No.1300 Sum of Inversions
ユーザー monnumonnu
提出日時 2021-07-11 15:44:50
言語 C++14
(gcc 13.3.0 + boost 1.87.0)
結果
AC  
実行時間 234 ms / 2,000 ms
コード長 1,876 bytes
コンパイル時間 2,162 ms
コンパイル使用メモリ 182,536 KB
実行使用メモリ 18,816 KB
最終ジャッジ日時 2024-07-02 03:05:46
合計ジャッジ時間 10,197 ms
ジャッジサーバーID
(参考情報)
judge4 / judge5
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 3
other AC * 34
権限があれば一括ダウンロードができます

ソースコード

diff #

#include <bits/stdc++.h>
using namespace std;
//#include <atcoder/all>
//using namespace atcoder;
using ll=long long;
using Graph=vector<vector<int>>;
#define MAX 1000000
//#define MOD 1000000007
#define MOD 998244353
//#define INF 1000000000
#define INF 1000000000000000000

class BIT{
  int n;
  vector<ll> a;
public:
  BIT(int n_):n(n_){
    a.resize(n+1,0);
  }
  void add(int i,ll x){
    while(i<=n){
      a[i]+=x;
      i+=i&(-i);
    }
  }
  ll sum(int i){
    ll ret=0;
    while(i>0){
      ret+=a[i];
      i-=i&(-i);
    }
    return ret;
  }
  ll sum(int l,int r){
    l--;
    return sum(r)-sum(l);
  }
};

int main(){
  int N;
  cin>>N;
  vector<ll> A(N);
  for(int i=0;i<N;i++){
    cin>>A[i];
  }
  vector<ll> nums=A;
  sort(nums.begin(),nums.end());
  nums.erase(unique(nums.begin(),nums.end()),nums.end());
  int n=nums.size();

  ll ans=0;

  BIT tree1(n);
  vector<ll> cnt1(N,0);
  for(int i=N-1;i>=0;i--){
    int k=lower_bound(nums.begin(),nums.end(),A[i])-nums.begin();
    tree1.add(k+1,1);
    cnt1[i]=tree1.sum(k);
  }
  vector<pair<ll,int>> a(N);
  for(int i=0;i<N;i++){
    a[i].first=A[i];
    a[i].second=i;
  }
  sort(a.begin(),a.end());
  BIT sum1(N);
  for(int j=0;j<N;j++){
    int i=a[j].second;
    ll x=sum1.sum(i+1,N);
    x%=MOD;
    ans+=A[i]*x%MOD;
    ans%=MOD;
    sum1.add(i+1,cnt1[i]);
  }
  //cout<<ans<<endl;

  BIT tree2(n);
  vector<ll> cnt2(N,0);
  for(int i=0;i<N;i++){
    int k=lower_bound(nums.begin(),nums.end(),A[i])-nums.begin();
    tree2.add(k+1,1);
    cnt2[i]=tree2.sum(k+2,n);
  }
  BIT sum2(N);
  for(int j=N-1;j>=0;j--){
    int i=a[j].second;
    ll x=sum2.sum(i);
    x%=MOD;
    ans+=A[i]*x%MOD;
    ans%=MOD;
    sum2.add(i+1,cnt2[i]);
  }
  //cout<<ans<<endl;
  for(int i=0;i<N;i++){
    //cout<<cnt1[i]<<" "<<cnt2[i]<<endl;
    ans+=(cnt1[i]*cnt2[i]%MOD)*A[i]%MOD;
    ans%=MOD;
  }
  cout<<ans<<endl;
}
0