結果

問題 No.3509 Get More Money
コンテスト
ユーザー gojoxd
提出日時 2026-04-18 22:15:18
言語 C++23
(gcc 15.2.0 + boost 1.89.0)
コンパイル:
g++-15 -O2 -lm -std=c++23 -Wuninitialized -DONLINE_JUDGE -o a.out _filename_
実行:
./a.out
結果
WA  
実行時間 -
コード長 4,848 bytes
記録
記録タグの例:
初AC ショートコード 純ショートコード 純主流ショートコード 最速実行時間
コンパイル時間 2,918 ms
コンパイル使用メモリ 346,948 KB
実行使用メモリ 21,120 KB
最終ジャッジ日時 2026-04-18 22:15:53
合計ジャッジ時間 30,191 ms
ジャッジサーバーID
(参考情報)
judge3_1 / judge2_1
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample WA * 1
other WA * 60
権限があれば一括ダウンロードができます

ソースコード

diff #
raw source code

#include <bits/stdc++.h>
 using namespace std;

 struct SegTree {
     struct Node { long long cnt=0, sum=0; };
     int n;
     vector<long long> val;
     vector<Node> st;
     SegTree(const vector<long long>& v): val(v) {
         int sz = val.size();
         n = 1;
         while (n < sz) n <<= 1;
         st.assign(2*n, {});
     }
     void add(int idx, long long c) {
         int i = idx + n;
         st[i].cnt += c;
         st[i].sum += c * val[idx];
         for (i >>= 1; i; i >>= 1) {
             st[i].cnt = st[i<<1].cnt + st[i<<1|1].cnt;
             st[i].sum = st[i<<1].sum + st[i<<1|1].sum;
         }
     }
     long long count_range(int l, int r, int i, int nl, int nr) {
         if (r < nl || nr < l) return 0;
         if (l <= nl && nr <= r) return st[i].cnt;
         int mid = (nl+nr)>>1;
         return count_range(l,r,i<<1,nl,mid) + count_range(l,r,i<<1|1,mid+1,nr);
     }
     long long count_range(int l, int r) {
         if (l > r) return 0;
         return count_range(l,r,1,0,n-1);
     }
     // remove k smallest in [l,r], return sum removed
     long long remove_k_smallest(int l, int r, long long k, int i, int nl, int nr) {
         if (k == 0 || r < nl || nr < l || st[i].cnt == 0) return 0;
         if (l <= nl && nr <= r) {
             if (st[i].cnt <= k) {
                 long long res = st[i].sum;
                 st[i].cnt = st[i].sum = 0;
                 return res;
             }
         }
         if (nl == nr) {
             long long take = min(k, st[i].cnt);
             st[i].cnt -= take;
             st[i].sum -= take * val[nl];
             return take * val[nl];
         }
         int mid = (nl+nr)>>1;
         long long res = 0;
         res += remove_k_smallest(l,r,k,i<<1,nl,mid);
         long long left_removed_cnt = count_range(l, min(r, mid));
         long long remain = k - left_removed_cnt;
         if (remain > 0) res += remove_k_smallest(l,r,remain,i<<1|1,mid+1,nr);
         st[i].cnt = st[i<<1].cnt + st[i<<1|1].cnt;
         st[i].sum = st[i<<1].sum + st[i<<1|1].sum;
         return res;
     }
     // remove k largest in [l,r], return sum removed
     long long remove_k_largest(int l, int r, long long k, int i, int nl, int nr) {
         if (k == 0 || r < nl || nr < l || st[i].cnt == 0) return 0;
         if (l <= nl && nr <= r) {
             if (st[i].cnt <= k) {
                 long long res = st[i].sum;
                 st[i].cnt = st[i].sum = 0;
                 return res;
             }
         }
         if (nl == nr) {
             long long take = min(k, st[i].cnt);
             st[i].cnt -= take;
             st[i].sum -= take * val[nl];
             return take * val[nl];
         }
         int mid = (nl+nr)>>1;
         long long res = 0;
         res += remove_k_largest(l,r,k,i<<1|1,mid+1,nr);
         long long right_removed_cnt = count_range(max(l, mid+1), r);
         long long remain = k - right_removed_cnt;
         if (remain > 0) res += remove_k_largest(l,r,remain,i<<1,nl,mid);
         st[i].cnt = st[i<<1].cnt + st[i<<1|1].cnt;
         st[i].sum = st[i<<1].sum + st[i<<1|1].sum;
         return res;
     }
 };

 int main() {
     ios::sync_with_stdio(false);
     cin.tie(nullptr);

     int T;
     cin >> T;
     while (T--) {
         int N;
         long long K;
         cin >> N >> K;
         vector<long long> A(N+1), B(N+1), C(N+1), D(N+1);
         for (int i=1;i<=N;i++) cin >> A[i];
         for (int i=1;i<=N;i++) cin >> B[i];
         for (int i=1;i<=N;i++) cin >> C[i];
         for (int i=1;i<=N;i++) cin >> D[i];

         vector<long long> vals;
         vals.reserve(N);
         for (int i=1;i<=N;i++) vals.push_back(C[i]);
         sort(vals.begin(), vals.end());
         vals.erase(unique(vals.begin(), vals.end()), vals.end());

         SegTree st(vals);
         long long total = 0;
         long long profit = 0;

         for (int i=N;i>=1;i--) {
             int idxC = lower_bound(vals.begin(), vals.end(), C[i]) - vals.begin();
             st.add(idxC, D[i]);
             total += D[i];

             if (total > K) {
                 long long excess = total - K;
                 st.remove_k_smallest(0, (int)vals.size()-1, excess, 1, 0, st.n-1);
                 total = K;
             }

             int pos = upper_bound(vals.begin(), vals.end(), A[i]) - vals.begin();
             if (pos < (int)vals.size()) {
                 long long avail = st.count_range(pos, (int)vals.size()-1);
                 long long take = min(B[i], avail);
                 if (take > 0) {
                     long long sumC = st.remove_k_largest(pos, (int)vals.size()-1, take, 1, 0, st.n-1);
                     profit += sumC - take * A[i];
                     total -= take;
                 }
             }
         }
         cout << profit << "\n";
     }
     return 0;
 }
0