結果
問題 |
No.2617 容量3のナップザック
|
ユーザー |
![]() |
提出日時 | 2024-01-26 22:20:21 |
言語 | C++17 (gcc 13.3.0 + boost 1.87.0) |
結果 |
TLE
|
実行時間 | - |
コード長 | 2,399 bytes |
コンパイル時間 | 2,317 ms |
コンパイル使用メモリ | 209,848 KB |
最終ジャッジ日時 | 2025-02-18 23:25:02 |
ジャッジサーバーID (参考情報) |
judge4 / judge1 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | AC * 2 |
other | AC * 2 TLE * 1 -- * 37 |
ソースコード
#include <bits/stdc++.h> using namespace std; template <typename T> struct sum_set{ multiset<long long> st1, st2; int cnt = 0; long long sum = 0; sum_set(){ } void insert(T x){ st2.insert(x); if (cnt > 0){ if (*st1.begin() < *prev(st2.end())){ sum -= *st1.begin(); st2.insert(*st1.begin()); st1.erase(st1.begin()); sum += *prev(st2.end()); st1.insert(*prev(st2.end())); st2.erase(prev(st2.end())); } } } void erase(T x){ if (st2.count(x) == 1){ st2.erase(st2.find(x)); } else { sum -= x; st1.erase(st1.find(x)); st1.insert(*prev(st2.end())); sum += *prev(st2.end()); st2.erase(prev(st2.end())); } } void increment(){ cnt++; st1.insert(*prev(st2.end())); sum += *prev(st2.end()); st2.erase(prev(st2.end())); } }; int main(){ int N, K; int seed, a, b, m; cin >> N >> K >> seed >> a >> b >> m; vector<long long> f(N * 2); f[0] = seed; for (int i = 1; i < N * 2; i++){ f[i] = (a * f[i - 1] + b) % m; } vector<int> w(N); vector<long long> v(N); for (int i = 0; i < N; i++){ w[i] = f[i] % 3 + 1; v[i] = w[i] * f[N + i]; } vector<vector<long long>> V(3); for (int i = 0; i < N; i++){ V[w[i] - 1].push_back(v[i]); } for (int i = 0; i < 3; i++){ sort(V[i].begin(), V[i].end(), greater<long long>()); } auto get = [&](int i, int j) -> long long { if (j >= V[i].size()){ return 0; } else { return V[i][j]; } }; long long sum3 = 0; for (int i = 0; i < K; i++){ sum3 += get(2, i); } vector<sum_set<long long>> st(2); vector<long long> sum(2, 0); st[1].insert(get(0, 1) + get(0, 2)); st[1].insert(get(1, 0)); st[1].increment(); sum[1] += get(0, 0); long long ans = 0; for (int i = 0; i <= K; i++){ ans = max(ans, sum3 + st[i % 2].sum + sum[i % 2]); sum[i % 2] += get(0, i); sum[i % 2] += get(0, i + 1); st[i % 2].insert(get(0, i * 3) + get(0, i * 3 + 1)); st[i % 2].insert(get(0, i * 3 + 2) + get(0, i * 3 + 3)); st[i % 2].insert(get(0, i * 3 + 4) + get(0, i * 3 + 5)); st[i % 2].insert(get(1, i)); st[i % 2].insert(get(1, i + 1)); st[i % 2].erase(get(0, i) + get(0, i + 1)); st[i % 2].increment(); st[i % 2].increment(); if (i < K){ sum3 -= get(2, K - 1 - i); } } cout << ans << endl; }