結果
| 問題 |
No.3118 Increment or Multiply
|
| コンテスト | |
| ユーザー |
|
| 提出日時 | 2025-05-21 16:01:07 |
| 言語 | C++23 (gcc 13.3.0 + boost 1.87.0) |
| 結果 |
WA
|
| 実行時間 | - |
| コード長 | 7,963 bytes |
| コンパイル時間 | 1,823 ms |
| コンパイル使用メモリ | 113,804 KB |
| 実行使用メモリ | 7,844 KB |
| 最終ジャッジ日時 | 2025-05-21 16:01:22 |
| 合計ジャッジ時間 | 13,710 ms |
|
ジャッジサーバーID (参考情報) |
judge2 / judge3 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 1 |
| other | AC * 25 WA * 10 |
ソースコード
#include <iostream>
#include <vector>
#include <algorithm>
#include <map> // Not strictly needed, but useful for some approaches
// Using long long for N, A, and costs/coefficients
using ll = long long;
const ll MOD = 998244353;
const ll INF_COST = 4e18; // A very large number for cost comparison
struct Func {
ll const_term; // C0 in C0 + C1*K
ll k_term; // C1 in C0 + C1*K (this will be negative, e.g. MOD-Am or MOD-1)
ll min_K_domain;
ll max_K_domain;
};
// Modular inverse for division by 2
ll inv2 = (MOD + 1) / 2;
// Sum of K from L to R, modulo MOD
ll sum_arithmetic_progression_K(ll L, ll R) {
if (L > R) return 0;
ll n = (R - L + 1); // Number of terms
// Sum = (L+R) * n / 2
ll term1 = (L % MOD + R % MOD) % MOD;
ll term2 = (n % MOD * inv2 % MOD);
return (term1 * term2) % MOD;
}
ll solve() {
ll N, A;
std::cin >> N >> A;
if (A == 1) {
// f(K) = N - K
// Sum (N-K) for K=1 to N = N*N - N(N+1)/2 = N(N-1)/2
ll N_mod = N % MOD;
ll num = (N_mod * ((N_mod - 1 + MOD) % MOD)) % MOD;
ll ans = (num * inv2) % MOD;
return ans;
}
std::vector<Func> funcs;
// Base case: m=0 (only additions)
// f(K) = N - K. const_term = N, k_term = -1
funcs.push_back({N % MOD, (MOD - 1), 1, N});
ll current_A_power = 1; // Stores A^m
for (int m = 1; m <= 62; ++m) { // m up to ~60-62 for N=10^18, A=2
ll prev_A_power = current_A_power;
if (A > INF_COST / prev_A_power) { // Check for overflow before multiplication
current_A_power = INF_COST; // Mark as too large
} else {
current_A_power = prev_A_power * A;
}
if (current_A_power > N || current_A_power == 0) { // current_A_power = 0 if overflow occurred weirdly
break;
}
// Strategy 1 (c0=0): cost = m + N - K * A^m
// const_term = m+N, k_term = -A^m
ll max_K_strat1 = N / current_A_power;
if (max_K_strat1 >= 1) {
ll const_t = (m + N % MOD) % MOD;
ll k_t = (MOD - (current_A_power % MOD)) % MOD;
funcs.push_back({const_t, k_t, 1, max_K_strat1});
}
// Strategy 2 (c0>0 optimal): cost = (X_m-K) + m + N - X_m*A^m
// X_m = floor(N/A^m)
// const_term = X_m + m + N - X_m*A^m, k_term = -1
ll X_m = N / current_A_power;
if (X_m >= 1) { // K must be <= X_m, and K >= 1
ll Xm_mod = X_m % MOD;
ll Am_mod = current_A_power % MOD;
ll term_Xm_Am = (Xm_mod * Am_mod) % MOD;
ll const_t = (Xm_mod + m % MOD + N % MOD - term_Xm_Am + MOD) % MOD;
ll k_t = (MOD - 1);
funcs.push_back({const_t, k_t, 1, X_m});
}
}
std::vector<ll> points;
points.push_back(1);
points.push_back(N + 1); // Sentinel for last interval N to N
for (size_t i = 0; i < funcs.size(); ++i) {
for (size_t j = i + 1; j < funcs.size(); ++j) {
// Intersection K of C_i + M_i*K = C_j + M_j*K
// (M_i - M_j)K = C_j - C_i
// K = (C_j - C_i) / (M_i - M_j)
ll Ci = funcs[i].const_term;
ll Mi = funcs[i].k_term; // This is already (MOD - val) form for negative
ll Cj = funcs[j].const_term;
ll Mj = funcs[j].k_term;
ll num = (Cj - Ci + MOD) % MOD;
ll den = (Mi - Mj + MOD) % MOD;
if (den == 0) continue; // Parallel lines or same line
// We need K_intersect. For K = num * den_inv (mod MOD). This is not for real K.
// We need real K where functions cross.
// Original coefficients: M_i_orig = -funcs[i].<actual_coeff_of_K>
// Let's use original expressions: const_val_i + k_coeff_i * K (k_coeff_i is negative)
// (k_coeff_i - k_coeff_j) * K = const_val_j - const_val_i
// K_exact = (const_val_j - const_val_i) / (k_coeff_i - k_coeff_j)
// Need to use non-modular values for k_coeff for division
// Reconstruct true coefficients for intersection calculation
// This part is tricky with modular arithmetic for coefficients.
// It's easier to evaluate at K_mid for each interval.
// The number of pieces is O(M), so few critical points.
// For simplicity, the critical points are usually just domain boundaries.
// Max_K_domain for each function:
points.push_back(funcs[i].max_K_domain + 1);
points.push_back(funcs[j].max_K_domain + 1);
}
}
std::sort(points.begin(), points.end());
points.erase(std::unique(points.begin(), points.end()), points.end());
ll total_f_sum = 0;
for (ll K_start : points) {
if (K_start > N) continue;
auto it = std::upper_bound(points.begin(), points.end(), K_start);
ll K_end = N;
if (it != points.end()) {
K_end = std::min(N, (*it) - 1);
}
if (K_start > K_end) continue;
ll K_mid = K_start; // Test point for the interval
ll min_true_cost_at_K_mid = N - K_mid; // From m=0 case
Func best_func = funcs[0]; // Default to m=0 function
ll temp_A_power = 1;
for (int m = 1; m <= 62; ++m) {
ll prev_A_power = temp_A_power;
if (A > INF_COST / prev_A_power) {
temp_A_power = INF_COST;
} else {
temp_A_power = prev_A_power * A;
}
if (temp_A_power > N || temp_A_power == 0) break;
// Strategy 1 (c0=0)
if (K_mid * temp_A_power <= N) { // Valid K_mid for this strategy
ll cost1 = m + N - K_mid * temp_A_power;
if (cost1 < min_true_cost_at_K_mid) {
min_true_cost_at_K_mid = cost1;
ll const_t = (m % MOD + N % MOD) % MOD;
ll k_t = (MOD - (temp_A_power % MOD)) % MOD;
best_func = {const_t, k_t, 1, N / temp_A_power};
}
}
// Strategy 2 (c0>0 optimal)
ll X_m = N / temp_A_power;
if (K_mid <= X_m) { // Valid K_mid for this strategy
ll cost2 = (X_m - K_mid) + m + N - X_m * temp_A_power;
if (cost2 < min_true_cost_at_K_mid) {
min_true_cost_at_K_mid = cost2;
ll Xm_mod = X_m % MOD;
ll Am_mod = temp_A_power % MOD;
ll term_Xm_Am = (Xm_mod * Am_mod) % MOD;
ll const_t = (Xm_mod + m % MOD + N % MOD - term_Xm_Am + MOD) % MOD;
ll k_t = (MOD - 1);
best_func = {const_t, k_t, 1, X_m};
}
}
}
// Ensure the chosen best_func is valid for the entire [K_start, K_end] range.
// The K_mid test should ideally be robust. The lower envelope of lines means one function is optimal over an interval.
// Max_K_domain check is important. The interval [K_start, K_end] must be within best_func's domain.
ll actual_K_end = std::min(K_end, best_func.max_K_domain);
if (K_start > actual_K_end) continue;
ll C0 = best_func.const_term;
ll C1 = best_func.k_term; // This is already (MOD - val) form
ll num_terms = (actual_K_end - K_start + 1);
ll sum_C0 = (num_terms % MOD * C0) % MOD;
ll sum_K = sum_arithmetic_progression_K(K_start, actual_K_end);
ll sum_C1K = (C1 * sum_K) % MOD;
ll current_interval_sum = (sum_C0 + sum_C1K + MOD) % MOD;
total_f_sum = (total_f_sum + current_interval_sum) % MOD;
}
return total_f_sum;
}
int main() {
std::ios_base::sync_with_stdio(false);
std::cin.tie(NULL);
int T;
std::cin >> T;
while (T--) {
std::cout << solve() << "\n";
}
return 0;
}