結果
問題 | No.1300 Sum of Inversions |
ユーザー | yakamoto |
提出日時 | 2020-11-27 22:12:07 |
言語 | C++17 (gcc 13.3.0 + boost 1.87.0) |
結果 | AC |
実行時間 | 105 ms / 2,000 ms |
コード長 | 6,982 bytes |
コンパイル時間 | 2,685 ms |
コンパイル使用メモリ | 213,000 KB |
最終ジャッジ日時 | 2025-01-16 07:33:41 |
ジャッジサーバーID (参考情報) | judge5 / judge5 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | AC * 3 |
other | AC * 34 |
ソースコード
/**
 * code generated by JHelper
 * More info: https://github.com/AlexeyDmitriev/JHelper
 * @author
 *
 * Solution for yukicoder No.1300 "Sum of Inversions".
 * The file is a shared competitive-programming prelude (aliases, macros,
 * debug printing, input helpers) followed by a Fenwick tree, a modular
 * integer type and the solver itself.
 */

#ifndef SOLUTION_COMMON_H
#include <bits/stdc++.h>

using namespace std;

using ll = long long;
using Pii = pair<int, int>;
template<typename T> using V = vector<T>;
using Vi = V<int>;

#define _1 first
#define _2 second
#define all(x) x.begin(), x.end()
#define pb push_back
#define lb lower_bound
#define amax(a, b) a = max(a, b)
#define amin(a, b) a = min(a, b)
// tmax/tmin expect a symbol named INF to be defined at the point of use.
#define tmax(_next, _prev, expr) if (_prev != INF) { auto prev = _prev; amax(_next, expr); }
#define tmin(_next, _prev, expr) if (_prev != INF) { auto prev = _prev; amin(_next, expr); }
#define dim2(a, b, init) vector(a, vector(b, init))
#define dim3(a, b, c, init) vector(a, vector(b, vector(c, init)))
#define dim4(a, b, c, d, init) vector(a, vector(b, vector(c, vector(d, init))))

#ifndef M_PI
static const double M_PI = acos(-1.0);
#endif

#ifdef MY_DEBUG
# define DEBUG(x) x
const bool isDebug = true;
#else
# define DEBUG(x)
const bool isDebug = false;
#endif

/** Streams a pair as "(first,second)". */
template<class A, class B>
std::ostream & operator <<(ostream &os, const pair<A, B> &p) {
    os << "(" << p._1 << "," << p._2 << ")";
    return os;
}

// ---- Debug printing to cerr (used by the debug(...) macro) ----
void __print(int x) {cerr << x;}
void __print(long x) {cerr << x;}
void __print(long long x) {cerr << x;}
void __print(unsigned x) {cerr << x;}
void __print(unsigned long x) {cerr << x;}
void __print(unsigned long long x) {cerr << x;}
void __print(float x) {cerr << x;}
void __print(double x) {cerr << x;}
void __print(long double x) {cerr << x;}
void __print(char x) {cerr << '\'' << x << '\'';}
void __print(const char *x) {cerr << '\"' << x << '\"';}
void __print(const string &x) {cerr << '\"' << x << '\"';}
void __print(bool x) {cerr << (x ? "true" : "false");}
void __print(const V<bool> &x) {for (auto i : x) cerr << i;}

/** Prints a pair as "(first,second)" using the __print overloads. */
template<typename T, typename U>
void __print(const pair<T, U> &x) {
    cerr << '(';
    __print(x.first);
    cerr << ',';
    __print(x.second);
    cerr << ')';
}

/** Prints any iterable as "{a,b,c}" using the __print overloads. */
template<typename T>
void __print(const T &x) {
    int f = 0;
    cerr << '{';
    for (auto const &i: x) cerr << (f++ ? "," : ""), __print(i);
    cerr << "}";
}

/** Terminates a debug line. */
void _print() {cerr << "]\n";}

/** Prints the arguments separated by ", " and then closes the debug line. */
template <typename T, typename... Rest>
void _print(T t, Rest... v) {
    __print(t);
    if (sizeof...(v)) cerr << ", ";
    _print(v...);
}

// NOTE: debug(...) expands to two statements; do not use it as the body of
// an unbraced if/for.
#ifdef MY_DEBUG
#define debug(x...) cerr << "[" << #x << "] = ["; _print(x)
#else
#define debug(x...)
#endif

/** Joins the elements of A into one string, separated by `delimiter`. */
template<class T>
string join(const V<T> &A, string delimiter = " ") {
    ostringstream os;
    for (size_t i = 0; i < A.size(); ++i) {
        if (i > 0) os << delimiter;
        os << A[i];
    }
    return os.str();
}

/** Reads A.size() elements from `in` into A (A must already be sized). */
template <typename T>
istream& operator>>(istream& in, vector<T> &A) {
    for (size_t i = 0; i < A.size(); i++) {
        in >> A[i];
    }
    return in;
}

/**
 * Reads N pairs "a b" from `in` and returns them as two parallel vectors.
 * `add` is added to every value read.
 */
template <typename T = int>
tuple<V<T>, V<T>> na2(istream& in, int N, int add = 0) {
    auto res = make_tuple(V<T>(N), V<T>(N));
    for (int i = 0; i < N; ++i) {
        in >> get<0>(res)[i] >> get<1>(res)[i];
        get<0>(res)[i] += add;
        get<1>(res)[i] += add;
    }
    return res;
}

/**
 * Reads an N x M grid from `in`. `add` is added to every value read.
 * The grid is built with element type T, so it works for T other than int.
 */
template <typename T = int>
V<V<T>> nm(istream& in, int N, int M, int add = 0) {
    V<V<T>> res(N, V<T>(M));
    for (int i = 0; i < N; ++i) {
        in >> res[i];
        if (add) {
            for (auto &a : res[i]) {
                a += add;
            }
        }
    }
    return res;
}

/**
 * Returns floor(num / d) for integral T; d must be non-zero.
 * C++ division truncates toward zero, so the quotient is lowered by one
 * exactly when the division is inexact and the operands have opposite signs.
 */
template <typename T>
inline T floorDiv(T num, T d) {
    T q = num / d;
    if (num % d != 0 && ((num < 0) != (d < 0))) --q;
    return q;
}

template<typename T>
inline T min2(T a, T b) {
    return min(a, b);
}

template<typename T>
inline T max2(T a, T b) {
    return max(a, b);
}

#define SOLUTION_COMMON_H
#endif //SOLUTION_COMMON_H

/**
 * Fenwick tree (binary indexed tree) over positions 0..n-1.
 * Internally the array is 1-based and padded to a power of two N >= n so
 * that lower_bound can descend by halving.
 */
template<typename T = int>
class BIT {
    const T zero = 0;
    int n;
    int N;
    V<T> bit;

    /** Smallest power of two >= x; returns 1 for x <= 1 (clz(0) is undefined). */
    static int calcN(int x) {
        if (x <= 1) return 1;
        int k = 1 << (31 - __builtin_clz(x));
        return k == x ? k : k << 1;
    }

public:
    BIT(int n): n(n), N(calcN(n)), bit(N + 1, zero) {}

    /** Adds x to the element at 0-based position i. */
    void add(int i, T x) {
        i++;
        while (i <= N) {
            bit[i] = bit[i] + x;
            i += i & -i;
        }
    }

    /** Sum over the half-open range [l, r). */
    T query(int l, int r) const {
        return sumUntil(r) - sumUntil(l);
    }

    /** Value stored at 0-based position i. */
    T get(int i) const {
        return sumUntil(i + 1) - sumUntil(i);
    }

    /** Sum of the first i elements, i.e. positions [0, i). */
    T sumUntil(int i) const {
        T ans = zero;
        while (i > 0) {
            ans += bit[i];
            i -= i & -i;
        }
        return ans;
    }

    /**
     * Returns the largest count `res` such that the sum of the first `res`
     * elements is < x. Requires T to support operator< and operator-=.
     */
    int lower_bound(T x) const {
        int k = N;
        int res = 0;
        while (k > 0) {
            if (res + k <= N && bit[res + k] < x) {
                x -= bit[res + k];
                res += k;
            }
            k /= 2;
        }
        return res;
    }
};

const int MOD = 998244353;

#ifndef MInt_H
/**
 * Integer modulo Mod, kept in the range [0, Mod).
 * The value is stored in an int and += relies on 2 * Mod fitting in it.
 */
template <unsigned int Mod>
class MInt {
    static_assert(Mod >= 1 && Mod <= (1u << 30),
                  "MInt stores values in int and needs 2 * Mod to fit");
private:
    int v;
public:
    MInt() : v(0) {}

    /** Intentionally implicit so plain integers mix with MInt in expressions. */
    MInt(long long x) {
        v = static_cast<int>(x % static_cast<long long>(Mod));
        if (v < 0) v += static_cast<int>(Mod);
    }

    MInt& operator +=(const MInt &that) {
        v += that.v;
        if (v >= static_cast<int>(Mod)) v -= static_cast<int>(Mod);
        return *this;
    }

    MInt& operator -=(const MInt &that) {
        v -= that.v;
        if (v < 0) v += static_cast<int>(Mod);
        return *this;
    }

    MInt& operator *=(const MInt &that) {
        v = static_cast<int>(static_cast<long long>(v) * that.v % Mod);
        return *this;
    }

    // NOTE: these postfix forms modify in place and return *this (the new
    // value), unlike the built-in postfix operators.
    MInt& operator ++(int) {
        *this += 1;
        return *this;
    }

    MInt& operator --(int) {
        *this -= 1;
        return *this;
    }

    friend MInt operator+(const MInt& a, const MInt& b) {
        return MInt(a) += b;
    }

    friend MInt operator-(const MInt& a, const MInt& b) {
        return MInt(a) -= b;
    }

    friend MInt operator*(const MInt& a, const MInt& b) {
        return MInt(a) *= b;
    }

    friend std::ostream& operator<<(std::ostream& out, const MInt &a) {
        out << a.v;
        return out;
    }
};

template<unsigned int Mod>
void __print(MInt<Mod> x) {
    std::cerr << x;
}

#define MInt_H
#endif //MInt_H

using mint = MInt<MOD>;

class C {
public:
    /**
     * Reads N and the sequence A from `in`, and writes to `out` the sum of
     * A_i + A_j + A_k over all i < j < k with A_i > A_j > A_k, modulo MOD.
     *
     * For every middle index j it gathers
     *   lt[j] = (count, sum) of elements left of j that are strictly larger,
     *   gt[j] = (count, sum) of elements right of j that are strictly smaller,
     * and adds cntL*cntR*A_j + sumL*cntR + cntL*sumR to the answer.
     */
    void solve(std::istream& in, std::ostream& out) {
        ios::sync_with_stdio(false);
        cin.tie(nullptr);

        int N;
        in >> N;
        V<ll> Aori(N);
        in >> Aori;

        // (value, index) pairs; values stay long long to avoid narrowing.
        V<pair<ll, int>> A;
        A.reserve(N);
        for (int i = 0; i < N; ++i) {
            A.emplace_back(Aori[i], i);
        }

        V<pair<mint, mint>> lt(N), gt(N);
        BIT<int> bitLt(N), bitGt(N);
        BIT<mint> bitLtSum(N), bitGtSum(N);

        // Descending by (value, index): everything already inserted is >= the
        // current value, and equal values enter with the larger index first,
        // so the prefix [0, idx) holds only strictly larger elements.
        sort(all(A), greater<>());
        debug(A);
        for (const auto &e : A) {
            const int idx = e._2;
            auto cntLt = bitLt.sumUntil(idx);
            auto sumLt = bitLtSum.sumUntil(idx);
            debug(e, cntLt, sumLt);
            lt[idx]._1 = cntLt;
            lt[idx]._2 = sumLt;
            bitLt.add(idx, 1);
            bitLtSum.add(idx, e._1);
        }

        // Ascending by (value, index): equal values enter with the smaller
        // index first, so the suffix [idx, N) holds only strictly smaller
        // elements (the current one has not been inserted yet).
        sort(all(A));
        for (const auto &e : A) {
            const int idx = e._2;
            auto cntGt = bitGt.sumUntil(N) - bitGt.sumUntil(idx);
            auto sumGt = bitGtSum.sumUntil(N) - bitGtSum.sumUntil(idx);
            debug(e, cntGt, sumGt);
            gt[idx]._1 = cntGt;
            gt[idx]._2 = sumGt;
            bitGt.add(idx, 1);
            bitGtSum.add(idx, e._1);
        }
        debug(lt, gt);

        mint ans = 0;
        for (int i = 0; i < N; ++i) {
            auto cnt = lt[i]._1 * gt[i]._1;  // number of triples with middle i
            debug(i, cnt);
            ans += cnt * mint(Aori[i]) + lt[i]._2 * gt[i]._1 + lt[i]._1 * gt[i]._2;
        }
        out << ans << '\n';
    }
};

int main() {
    C solver;
    std::istream& in(std::cin);
    std::ostream& out(std::cout);
    solver.solve(in, out);
    return 0;
}