結果
| 問題 | No.1309 テスト |
| コンテスト | |
| ユーザー | QCFium |
| 提出日時 | 2020-04-16 22:51:25 |
| 言語 | C++14 (gcc 13.3.0 + boost 1.87.0) |
| 結果 | AC |
| 実行時間 | 817 ms / 4,000 ms |
| コード長 | 12,343 bytes |
| コンパイル時間 | 1,959 ms |
| コンパイル使用メモリ | 181,016 KB |
| 実行使用メモリ | 6,944 KB |
| 最終ジャッジ日時 | 2024-09-13 03:24:36 |
| 合計ジャッジ時間 | 9,774 ms |
| ジャッジサーバーID (参考情報) | judge4 / judge1 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 2 |
| other | AC * 85 |
ソースコード
#include <bits/stdc++.h>
// Reads one decimal int from stdin.
// Returns 0 when no integer could be read (EOF or malformed input) instead of
// returning an uninitialized value, which would be undefined behavior.
int ri() {
	int value = 0; // named `value` so it does not shadow the global `n`
	if (scanf("%d", &value) != 1) return 0;
	return value;
}
// 64-bit integer shorthands used throughout the file.
using s64 = std::int64_t;
using u64 = std::uint64_t;
// Query parameters shared by all solvers; main() sets them from the input.
int n;   // array length (Fast::run asserts that it is odd)
s64 max; // upper bound of the element values; elements lie in [0, max]
// "Impossible" sentinel; solvers return / store -INF.
// A typed constant instead of a macro gives it a fixed 64-bit type and normal scoping.
constexpr s64 INF = 1000000000000000000;
// Brute-force reference DP for the same question Fast::run answers. It is not called
// anywhere in this file; it is kept only for cross-checking on tiny inputs.
namespace DP {
// Maximum sum of a non-decreasing array of n integers in [0, max] whose element at
// 0-based index n / 2 equals `median` and whose most frequent values (every value that
// attains the highest multiplicity) have average exactly `mode`. Returns -1 if none exists.
//
// State dp[j][k][l] = best sum of the first j elements, where
//   l = highest multiplicity used so far, and
//   k = max * n + (sum over the values with multiplicity l of (value - mode) * l).
// The max * n offset keeps k non-negative. The table holds
// (n + 1) * (2 * max * n + 1) * (n + 1) entries, so this only fits very small inputs.
s64 run(s64 median, s64 mode) {
std::vector<std::vector<std::vector<s64> > > dp(n + 1,
std::vector<std::vector<s64> > (2 * max * n + 1, std::vector<s64> (n + 1, -INF)));
dp[0][max * n][0] = 0;
// values are considered in increasing order, so every array built is non-decreasing;
// a value is either skipped or appended m >= 1 times
for (int i = 0; i <= max; i++) {
// descending j: states created by appending value i are never extended by i again
for (int j = n; j >= 0; j--) {
for (int k = 0; k <= 2 * max * n; k++) {
for (int l = 0; l <= n; l++) {
if (dp[j][k][l] == -INF) continue;
// append m copies of value i, occupying positions [j, j + m)
for (int m = 1; j + m <= n; m++) {
// the run covering index n / 2 must consist of the median value
if (j <= n / 2 && j + m > n / 2 && i != median) continue;
int next_k = k;
if (m == l) next_k += (i - mode) * m; // ties the highest multiplicity: i joins the modes
else if (m > l) next_k = max * n + (i - mode) * m; // new highest multiplicity: modes restart with i
// (m < l: i is not a mode, so k is unchanged)
dp[j + m][next_k][std::max(l, m)] = std::max(dp[j + m][next_k][std::max(l, m)], dp[j][k][l] + i * m);
}
}
}
}
}
// accept only complete arrays whose modes average exactly `mode` (k back at the offset)
s64 res = -1;
for (int i = 1; i <= n; i++) res = std::max(res, dp[n][max * n][i]);
return res;
}
};
// Slow reference version of Fast. calc() below is the same algorithm as Fast::calc.
// run() differs from Fast::run only in that it tries every split of the modes' value sum
// between the two sides of the median (left_sum in [0, mode_all_sum]) instead of a few
// candidate splits. random_check() treats its result as the expected answer.
namespace Gu {
// Maximum sum of an array of exactly all_num integers in [min, max] that has exactly
// mode_num modes, each appearing exactly mode_freq times, whose distinct values sum to
// mode_sum. Returns -INF if no such array exists.
s64 calc(int mode_num, s64 mode_sum, int mode_freq, s64 min, s64 max, int all_num) {
assert(mode_freq > 1);
// number of non-mode elements; each non-mode value may repeat at most mode_freq - 1 times
int other = all_num - mode_num * mode_freq;
assert(other >= 0);
if (!mode_num) {
// no modes on this side, so only mode_sum == 0 is consistent
if (!mode_sum) {
// non-modes take the top values: other_block values (mode_freq - 1) times each,
// then one more value other_leftover times
int other_block = other / (mode_freq - 1);
int other_leftover = other % (mode_freq - 1);
if (other_block + !!other_leftover > max - min + 1) return -INF;
return other_block * (max + max - other_block + 1) / 2 * (mode_freq - 1) + other_leftover * (max - other_block);
}
return -INF; // invalid
}
// not enough distinct values in [min, max] for the modes
if (max - min + 1 < mode_num) return -INF;
// excess of mode_sum over the smallest sum of mode_num distinct values >= min
s64 clearance = mode_sum - mode_num * (min + min + mode_num - 1) / 2;
if (clearance < 0) return -INF;
// mode_sum exceeds the largest sum of mode_num distinct values <= max
if (mode_num * (max + max - mode_num + 1) / 2 < mode_sum) return -INF;
int other_block = other / (mode_freq - 1);
int other_leftover = other % (mode_freq - 1);
// an exact division is re-expressed as one fewer full block plus a full-size leftover
if (other_block && !other_leftover) other_block--, other_leftover = mode_freq - 1;
if (other_block + !!other_leftover + mode_num > max - min + 1) return -INF;
// distinct values in [min, max] used by neither the modes nor the non-modes
int number_clearance = max - min + 1 - (other_block + !!other_leftover + mode_num);
int slide = clearance / mode_num - number_clearance;
// starting point: the modes contribute mode_sum * mode_freq and the non-modes take the top values
s64 sum = mode_sum * mode_freq + other_block * (max + max - other_block + 1) / 2 * (mode_freq - 1)
+ other_leftover * (max - other_block);
// NOTE(review): the corrections below appear to charge for modes that have to move into the
// range the non-modes were assumed to occupy (slide >= 0) — confirm.
if (slide > 0) {
sum -= (other_leftover + (s64) (mode_freq - 1) * (slide - 1)) * mode_num;
sum -= clearance % mode_num * (mode_freq - 1);
} else if (!slide) sum -= clearance % mode_num * other_leftover;
return sum;
}
// Maximum sum of a non-decreasing array of n integers in [0, max] whose median is `median`
// and whose modes average to `mode`; -1 if impossible. The left/right split of the modes'
// value sum is brute-forced, so each (freq, left, right) configuration costs
// O(mode_all_sum) calc() calls.
s64 run(s64 median, s64 mode) {
int half = n / 2;
// count of values in (median, max]
s64 upper = max - median;
s64 res = -1;
// freq == 1: all values are distinct, so every element is a mode and the sum must be
// exactly mode * n; it is accepted when it lies in [min_sum, max_sum]
if (median >= half && median + half <= max) {
s64 min_sum = median + (s64) (half - 1) * half / 2 + (s64) (median + 1 + median + half) * half / 2;
s64 max_sum = median + (s64) (median - 1 + median - half) * half / 2 + (s64) (max + max - half + 1) * half / 2;
if (mode * n >= min_sum && mode * n <= max_sum) res = mode * n;
}
// freq = multiplicity of the modes; left / right = number of modes below / above the median
for (int freq = 2; freq <= n; freq++) {
for (int left = 0; left * freq <= half && left <= median; left++) {
for (int right = 0; right * freq <= half && right <= upper; right++) {
// feasible bounds of the half-open run [median_l, median_r) of elements equal to median:
// the left side holds left * freq mode copies plus at most freq - 1 copies of each other
// value below the median; the right side is bounded the same way with `upper` values
int median_l_min = left * freq;
int median_l_max = std::min<s64>(half, left * freq + (median - left) * (freq - 1));
int median_r_max = n - right * freq;
int median_r_min = std::max<s64>(half + 1, n - right * freq - (upper - right) * (freq - 1));
assert(median_l_min <= median_l_max);
assert(median_r_min <= median_r_max);
// std::cerr << freq << " " << left << "," << right << std::endl;
if (left || right) { // don't use median as (one of) mode(s)
// the modes' distinct values must sum to mode * (number of modes)
s64 mode_all_sum = mode * (left + right);
// the median is not a mode, so its run is at most freq - 1 long
int median_r = median_r_min;
int median_l = std::max(median_l_min, median_r - (freq - 1));
if (median_l <= half) {
// if (freq == 3 && left == 0 && right == 1) std::cerr << "yay:" << median_l_max << std::endl;
assert(median_r > half);
// brute force: every split of mode_all_sum between the left and right modes
for (s64 left_sum = 0; left_sum <= mode_all_sum; left_sum++) {
s64 cur = (median_r - median_l) * median;
cur += calc(left, left_sum, freq, 0, median - 1, median_l);
cur += calc(right, mode_all_sum - left_sum, freq, median + 1, max, n - median_r);
res = std::max(res, cur);
}
}
}
{ // use median as (one of) mode(s)
// the median run has exactly freq elements: [median_l, median_l + freq)
median_l_min = std::max(median_l_min, median_r_min - freq);
median_l_max = std::min(median_l_max, median_r_max - freq);
if (median_l_max >= median_l_min) {
int median_l = median_l_min;
assert(median_l <= half);
assert(median_l + freq > half);
// the median itself is one of the modes, so its value is removed from the required sum
s64 mode_all_sum = mode * (left + right + 1) - median;
for (s64 left_sum = 0; left_sum <= mode_all_sum; left_sum++) {
s64 cur = freq * median;
cur += calc(left, left_sum, freq, 0, median - 1, median_l);
cur += calc(right, mode_all_sum - left_sum, freq, median + 1, max, n - median_l - freq);
res = std::max(res, cur);
}
}
}
}
}
}
return res;
}
};
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
namespace Fast {
// Maximum sum of an array of exactly all_num integers, each in [min, max],
// that has exactly mode_num modes, each appearing exactly mode_freq times,
// and whose distinct mode values sum to mode_sum.
// Returns -INF if no such array exists.
// O(1).
s64 calc(int mode_num, s64 mode_sum, int mode_freq, s64 min, s64 max, int all_num) {
assert(mode_freq > 1);
// number of non-mode elements; each non-mode value may repeat at most mode_freq - 1 times
int other = all_num - mode_num * mode_freq;
assert(other >= 0);
if (!mode_num) {
// no modes on this side, so only mode_sum == 0 is consistent
if (!mode_sum) {
// non-modes take the top values: other_block values (mode_freq - 1) times each,
// then one more value other_leftover times
int other_block = other / (mode_freq - 1);
int other_leftover = other % (mode_freq - 1);
if (other_block + !!other_leftover > max - min + 1) return -INF;
return other_block * (max + max - other_block + 1) / 2 * (mode_freq - 1) + other_leftover * (max - other_block);
}
return -INF; // invalid
}
// not enough distinct values in [min, max] for the modes
if (max - min + 1 < mode_num) return -INF;
// excess of mode_sum over the smallest sum of mode_num distinct values >= min
s64 adding = mode_sum - mode_num * (min + min + mode_num - 1) / 2;
if (adding < 0) return -INF;
// mode_sum exceeds the largest sum of mode_num distinct values <= max
if (mode_num * (max + max - mode_num + 1) / 2 < mode_sum) return -INF;
int other_block = other / (mode_freq - 1);
int other_leftover = other % (mode_freq - 1);
// an exact division is re-expressed as one fewer full block plus a full-size leftover
if (other_block && !other_leftover) other_block--, other_leftover = mode_freq - 1;
if (other_block + !!other_leftover + mode_num > max - min + 1) return -INF;
// distinct values in [min, max] used by neither the modes nor the non-modes
int default_clearance = max - min + 1 - (other_block + !!other_leftover + mode_num);
int slide = adding / mode_num - default_clearance;
// starting point: the modes contribute mode_sum * mode_freq and the non-modes take the top values
s64 sum = mode_sum * mode_freq + other_block * (max + max - other_block + 1) / 2 * (mode_freq - 1)
+ other_leftover * (max - other_block);
// NOTE(review): the corrections below appear to charge for modes that have to move into the
// range the non-modes were assumed to occupy (slide >= 0) — confirm.
if (slide > 0) {
sum -= (other_leftover + (s64) (mode_freq - 1) * (slide - 1)) * mode_num;
sum -= adding % mode_num * (mode_freq - 1);
} else if (!slide) sum -= adding % mode_num * other_leftover;
return sum;
}
// Sum of the mode_num consecutive integers ending at max - other_used, where
// other = all_num - mode_num * mode_freq and other_used = ceil(other / (mode_freq - 1)).
// Returns 0 when mode_num == 0. solve_sub() tries this value as a candidate split point.
// NOTE(review): presumably this is a point where calc()'s formula changes as mode_sum grows — confirm.
s64 get_candidate(int mode_num, int mode_freq, s64 max, int all_num) {
assert(mode_freq > 1);
int other = all_num - mode_num * mode_freq;
assert(other >= 0);
if (!mode_num) return 0;
int other_used = other / (mode_freq - 1);
if (other % (mode_freq - 1)) other_used++;
return (s64) (max - other_used + max - other_used - mode_num + 1) * mode_num / 2;
}
// Maximum sum of a non-decreasing array a of exactly n integers, each in [0, max], such that
// its median is `median` and the median value occupies exactly a[median_l, median_r),
// there are exactly `left` modes in a[0, median_l) and exactly `right` modes in a[median_r, n),
// the distinct modes' values sum to mode_all_sum (the median itself, if it is a mode, is not counted),
// and each of those modes appears exactly freq times.
// Returns -1 if no such array exists.
// O(1): only a handful of candidate splits of mode_all_sum are evaluated.
s64 solve_sub(s64 median, int median_l, int median_r, int freq, s64 mode_all_sum, int left, int right) {
s64 lower = 0, upper = mode_all_sum; // lower and upper bound(both inclusive) of sum of modes on the left side
// left modes are `left` distinct values in [0, median - 1]
lower = std::max(lower, (s64) (left - 1) * left / 2);
upper = std::min(upper, (s64) (median - 1 + median - left) * left / 2);
// right modes are `right` distinct values in [median + 1, max]; bound the left share by complement
lower = std::max(lower, mode_all_sum - (s64) (max + max - right + 1) * right / 2);
upper = std::min(upper, mode_all_sum - (s64) (median + 1 + median + right) * right / 2);
// the split endpoints plus the points derived from get_candidate() on each side
std::vector<s64> candidates{lower, upper};
candidates.push_back(get_candidate(left, freq, median - 1, median_l));
candidates.push_back(candidates.back() + left);
candidates.push_back(mode_all_sum - get_candidate(right, freq, max, n - median_r));
candidates.push_back(candidates.back() - right);
s64 res = -1;
for (auto left_sum : candidates) if (left_sum >= lower && left_sum <= upper) {
s64 cur = (median_r - median_l) * median;
cur += calc(left, left_sum, freq, 0, median - 1, median_l);
cur += calc(right, mode_all_sum - left_sum, freq, median + 1, max, n - median_r);
res = std::max(res, cur);
}
return res;
}
// Answer for one query: maximum sum of a non-decreasing array of n integers in [0, max]
// whose median (index n / 2) is `median` and whose modes average to `mode`; -1 if impossible.
s64 run(s64 median, s64 mode) {
// the median index is only unambiguous for odd n
assert(n & 1);
int half = n / 2;
// count of values in (median, max]
s64 upper = max - median;
s64 res = -1;
// freq == 1: all values are distinct, so every element is a mode and the sum must be
// exactly mode * n; it is accepted when it lies in [min_sum, max_sum]
if (median >= half && median + half <= max) {
s64 min_sum = median + (s64) (half - 1) * half / 2 + (s64) (median + 1 + median + half) * half / 2;
s64 max_sum = median + (s64) (median - 1 + median - half) * half / 2 + (s64) (max + max - half + 1) * half / 2;
if (mode * n >= min_sum && mode * n <= max_sum) res = mode * n;
}
// freq = multiplicity of the modes; left / right = number of modes below / above the median
for (int freq = 2; freq <= n; freq++) {
for (int left = 0; left * freq <= half && left <= median; left++) {
for (int right = 0; right * freq <= half && right <= upper; right++) {
// feasible bounds of the half-open run [median_l, median_r) of elements equal to median:
// the left side holds left * freq mode copies plus at most freq - 1 copies of each other
// value below the median; the right side is bounded the same way with `upper` values
int median_l_min = left * freq;
int median_l_max = std::min<s64>(half, left * freq + (median - left) * (freq - 1));
int median_r_max = n - right * freq;
int median_r_min = std::max<s64>(half + 1, n - right * freq - (upper - right) * (freq - 1));
assert(median_l_min <= median_l_max);
assert(median_r_min <= median_r_max);
// std::cerr << freq << " " << left << "," << right << std::endl;
if (left || right) { // don't use median as (one of) mode(s)
// the modes' distinct values must sum to mode * (number of modes)
s64 mode_all_sum = mode * (left + right);
// the median is not a mode, so its run is at most freq - 1 long
int median_r = median_r_min;
int median_l = std::max(median_l_min, median_r - (freq - 1));
if (median_l <= half) {
assert(median_r > half);
res = std::max(res, solve_sub(median, median_l, median_r, freq, mode_all_sum, left, right));
}
}
{ // use median as (one of) mode(s)
// the median run has exactly freq elements: [median_l, median_l + freq)
median_l_min = std::max(median_l_min, median_r_min - freq);
median_l_max = std::min(median_l_max, median_r_max - freq);
if (median_l_max >= median_l_min) {
int median_l = median_l_min;
assert(median_l <= half);
assert(median_l + freq > half);
// the median itself is one of the modes, so its value is removed from the required sum
s64 mode_all_sum = mode * (left + right + 1) - median;
res = std::max(res, solve_sub(median, median_l, median_l + freq, freq, mode_all_sum, left, right));
}
}
}
}
}
return res;
}
};
bool random_check() {
std::random_device rnd_dev;
std::mt19937 rnd(rnd_dev() ^ clock());
for (int median = 0; median <= max; median++) {
for (int mode = 0; mode <= max; mode++) {
int r0 = Gu::run(median, mode);
int r1 = Fast::run(median, mode);
if (r0 != r1) {
std::cerr << "!!!!! FAILED !!!!!" << std::endl;
std::cerr << n << " " << max << " " << median << " " << mode << std::endl;
std::cerr << "correct:" << r0 << " wrong:" << r1 << std::endl;
std::cerr << std::endl;
return false;
}
}
}
return true;
}
// Runs random_check() for every odd array length 1, 3, ..., n (Fast::run requires odd n),
// keeping the current global max. The global n doubles as the loop variable because
// Gu::run and Fast::run read it; it is restored to the caller's value afterwards.
// random_check() stops at the first mismatch of each length, so at most one mismatch is
// reported per length, and checking continues with the next length.
void random_check_all() {
	const int n_local = n;
	for (n = 1; n <= n_local; n += 2) random_check();
	n = n_local; // the loop leaves n past n_local; restore it
}
int main() {
for (int i = 0; i < 10; i++) {
n = ri();
max = ri();
int median = ri();
int mode = ri();
printf("%" PRId64 "\n", Fast::run(median, mode));
}
return 0;
}
QCFium