結果

問題 No.1300 Sum of Inversions
ユーザー DriceDrice
提出日時 2020-11-27 22:30:29
言語 C++14
(gcc 13.3.0 + boost 1.87.0)
結果
AC  
実行時間 221 ms / 2,000 ms
コード長 2,038 bytes
コンパイル時間 855 ms
コンパイル使用メモリ 47,048 KB
実行使用メモリ 11,864 KB
最終ジャッジ日時 2024-07-26 18:48:53
合計ジャッジ時間 7,222 ms
ジャッジサーバーID
(参考情報)
judge3 / judge2
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 3
other AC * 34
権限があれば一括ダウンロードができます

ソースコード

diff #

#include<cstdio>
#include<algorithm>
const long long mod = 998244353;
//bit[0]: cnt, bit[1]: sum
long long bit[2][200005];
long long a[200005];
long long b[200005];
long long c[200005]; // count
int map[200005];

long long ask(int p,long long bit[]){
    long long res = 0;
    while(p!=0){
        res += bit[p];
        if(res>=mod) res -= mod;
        p -= p&-p;
    }
    return res;
}

void change(int p, long long v, int n, long long bit[]){
    while(p<=n){
        bit[p] = (bit[p]+v)%mod;
        p += p&-p;
    }
}

int preWork(int n){
    std::sort(map+1,map+1+n);
    int  p = 1, size = 0;
    while(p<=n){
        int np = p;
        while(np+1<=n && map[np+1]==map[p]) np++;
        map[++size] = map[p];
        p = np+1;
    }
    return size;
}

int getId(int u,int n){
    int L = 1, R = n;
    int res = -1;
    while(L<=R){
        int M = (L+R)/2;
        if(map[M]<=u){
            res = M;
            L = M+1;
        }
        else R = M-1;
    }
    return res;
}

int main(){
    int n;
    scanf("%d",&n);
    for(int i = 1; i <= n; i++){ 
        scanf("%lld",&a[i]);
        map[i] = a[i];
    }
    int size = preWork(n);
    for(int i = n; i >= 1; i--){
        int upper = getId(a[i]-1,size);
        if(upper!=-1){
            long long cnt = ask(upper,bit[0]);
            long long sum = ask(upper,bit[1]);
            b[i] = (cnt*a[i]%mod+sum)%mod;
            c[i] = cnt;
        }
        int u = getId(a[i],size);
        change(u,1,n,bit[0]);
        change(u,a[i],n,bit[1]);
    }
    for(int i = 1; i <= n; i++) bit[0][i] = bit[1][i] = 0;
    long long ans = 0;
    for(int i = n; i >= 1; i--){
        int upper = getId(a[i]-1,size);
        if(upper!=-1){
            long long cnt = ask(upper,bit[0]);
            long long sum = ask(upper,bit[1]);
            long long cur = (cnt*a[i]%mod+sum)%mod;
            ans = (ans+cur)%mod;
        }
        int u = getId(a[i],size);
        change(u,c[i],n,bit[0]);
        change(u,b[i],n,bit[1]);
    }
    printf("%lld\n",ans);
    return 0;
}
0