結果
| 問題 | No.1300 Sum of Inversions | 
| コンテスト | |
| ユーザー |  siman | 
| 提出日時 | 2022-03-19 03:37:16 | 
| 言語 | C++17(clang) (17.0.6 + boost 1.87.0) | 
| 結果 | 
                                AC
                                 
                             | 
| 実行時間 | 233 ms / 2,000 ms | 
| コード長 | 1,829 bytes | 
| コンパイル時間 | 3,817 ms | 
| コンパイル使用メモリ | 144,036 KB | 
| 実行使用メモリ | 16,232 KB | 
| 最終ジャッジ日時 | 2024-10-03 16:39:34 | 
| 合計ジャッジ時間 | 10,578 ms | 
| ジャッジサーバーID (参考情報) | judge2 / judge5 | 
(要ログイン)
| ファイルパターン | 結果 | 
|---|---|
| sample | AC * 3 | 
| other | AC * 34 | 
ソースコード
#include <cassert>
#include <cmath>
#include <algorithm>
#include <iostream>
#include <iomanip>
#include <climits>
#include <map>
#include <queue>
#include <set>
#include <cstring>
#include <vector>
using namespace std;
typedef long long ll;
const ll MOD = 998244353;
struct Node {
  int idx;
  ll value;
  Node(int idx = -1, ll value = -1) {
    this->idx = idx;
    this->value = value;
  }
  bool operator<(const Node &n) const {
    if (value == n.value) {
      return idx > n.idx;
    } else {
      return value > n.value;
    }
  }
};
class BinaryIndexTree {
  public:
    vector <ll> bit;
    int N;
    BinaryIndexTree(int n) {
      N = n;
      for (int i = 0; i <= N; ++i) {
        bit.push_back(0);
      }
    }
    ll sum(int i) {
      ll ret = 0;
      while (i > 0) {
        ret += bit[i];
        ret %= MOD;
        i -= i & -i;
      }
      return ret;
    }
    void add(int i, ll x) {
      while (i <= N) {
        bit[i] += x;
        bit[i] %= MOD;
        i += i & -i;
      }
    }
};
int main() {
  int N;
  cin >> N;
  vector<ll> A(N);
  vector<Node> nodes;
  for (int i = 0; i < N; ++i) {
    cin >> A[i];
    nodes.push_back(Node(i + 1, A[i]));
  }
  sort(nodes.begin(), nodes.end());
  BinaryIndexTree bit1_1(N + 1);
  BinaryIndexTree bit1_2(N + 1);
  BinaryIndexTree bit2_1(N + 1);
  BinaryIndexTree bit2_2(N + 1);
  ll ans = 0;
  for (Node &node : nodes) {
    ll cnt_1 = bit1_1.sum(node.idx - 1);
    ll cnt_2 = bit2_1.sum(node.idx - 1);
    ll sum_1 = bit1_2.sum(node.idx - 1);
    ll sum_2 = bit2_2.sum(node.idx - 1);
    bit1_1.add(node.idx, 1);
    bit1_2.add(node.idx, node.value);
    bit2_1.add(node.idx, cnt_1);
    bit2_2.add(node.idx, sum_1 + cnt_1 * node.value);
    ans += sum_2 + cnt_2 * node.value;
    ans %= MOD;
  }
  cout << ans << endl;
  return 0;
}
            
            
            
        