結果
| 問題 |
No.2062 Sum of Subset mod 999630629
|
| コンテスト | |
| ユーザー |
|
| 提出日時 | 2022-10-02 16:34:53 |
| 言語 | C++14 (gcc 13.3.0 + boost 1.87.0) |
| 結果 |
WA
|
| 実行時間 | - |
| コード長 | 2,734 bytes |
| コンパイル時間 | 2,049 ms |
| コンパイル使用メモリ | 171,028 KB |
| 実行使用メモリ | 13,312 KB |
| 最終ジャッジ日時 | 2024-12-25 12:04:41 |
| 合計ジャッジ時間 | 6,270 ms |
|
ジャッジサーバーID (参考情報) |
judge2 / judge3 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 1 WA * 2 |
| other | AC * 1 WA * 28 |
ソースコード
#include<bits/stdc++.h>
using namespace std;
#define LL long long
const int maxn = 2e5 + 2;
const LL mod = 998244353;
int n;
LL sg1[maxn << 2], sg2[maxn << 2];
int lz1[maxn << 2], lz2[maxn << 2];
LL a[maxn];
int idx[maxn];
LL fw[maxn];
inline int lowbit(int x) { return x & -x; }
void add(int x) {
while (x <= n) {
++fw[x];
x += lowbit(x);
}
}
LL tot(int x) {
LL ret = 0;
while (x) {
ret += fw[x];
x -= lowbit(x);
}
return ret;
}
bool cmp(int x, int y) {
return a[x] < a[y];
}
LL qpow(LL a,LL b){
LL ans=1;
while(b){
if(b&1)ans=ans*a%mod;
b>>=1;
a=a*a%mod;
}
return ans%mod;
}
LL pow2k[maxn], inv2k[maxn];
void init() {
pow2k[0] = 1;
for (int k = 1; k <= n; ++k) {
pow2k[k] = pow2k[k-1] * 2 % mod;
}
inv2k[0] = 1;
inv2k[1] = qpow(2, mod - 2);
for (int k = 2; k <= n; ++k) {
inv2k[k] = inv2k[k-1] * inv2k[1] % mod;
}
}
#define lc (u << 1)
#define rc (lc | 1)
inline void pushup(LL sg[], int u) {
sg[u] = (sg[lc] + sg[rc]) % mod;
}
inline void down(LL sg[], int lz[], int u, int v = 1) {
lz[u] += v;
sg[u] = sg[u] * pow2k[v] % mod;
}
inline void pushdown(LL sg[], int lz[], int u) {
if (lz[u] > 0) {
down(sg, lz, lc, lz[u]);
down(sg, lz, rc, lz[u]);
lz[u] = 0;
}
}
#define pushup() pushup(sg, u)
#define pushdown() pushdown(sg, lz, u)
#define down() down(sg, lz, u)
void build(LL sg[], int u = 1, int l = 1, int r = n + 1) {
if (l + 1 == r) {
sg[u] = 1;
} else {
int m = (l + r) >> 1;
build(sg, lc, l, m);
build(sg, rc, m, r);
pushup();
}
}
void update(LL sg[], int lz[], int ll, int rr, int u = 1, int l = 1, int r = n + 1) {
if (ll <= l && r <= rr) {
down();
} else {
int m = (l + r) >> 1;
pushdown();
if (ll < m) update(sg, lz, ll, rr, lc, l, m);
if (m < rr) update(sg, lz, ll, rr, rc, m, r);
pushup();
}
}
LL query(LL sg[], int lz[], int ll, int rr, int u = 1, int l = 1, int r = n + 1) {
if (ll <= l && r <= rr) {
return sg[u];
} else {
int m = (l + r) >> 1;
pushdown();
LL ret = 0;
if (ll < m) ret += query(sg, lz, ll, rr, lc, l, m);
if (m < rr) ret += query(sg, lz, ll, rr, rc, m, r);
pushup();
return ret % mod;
}
}
int main() {
cin.tie(0)->sync_with_stdio(false);
cin >> n;
for (int i = 1; i <= n; ++i) {
cin >> a[i];
}
init();
build(sg1);
build(sg2);
for (int i = 0; i < n; ++i) {
idx[i] = i + 1;
}
sort(idx, idx + n, cmp);
LL ans = 0;
for (int p = 0; p < n; ++p) {
int id = idx[p];
LL ltot = tot(id);
LL rtot = tot(n) - ltot;
LL L = query(sg1, lz1, 1, id + 1) * inv2k[rtot] % mod;
LL R = query(sg2, lz2, id, n + 1) * inv2k[ltot] % mod;
ans = (ans + L * R % mod * a[id] % mod) % mod;
update(sg1, lz1, 1, id + 1);
update(sg2, lz2, id, n + 1);
add(id);
}
cout << ans << "\n";
return 0;
}