結果
問題 |
No.3118 Increment or Multiply
|
ユーザー |
|
提出日時 | 2025-05-21 16:01:07 |
言語 | C++23 (gcc 13.3.0 + boost 1.87.0) |
結果 |
WA
|
実行時間 | - |
コード長 | 7,963 bytes |
コンパイル時間 | 1,823 ms |
コンパイル使用メモリ | 113,804 KB |
実行使用メモリ | 7,844 KB |
最終ジャッジ日時 | 2025-05-21 16:01:22 |
合計ジャッジ時間 | 13,710 ms |
ジャッジサーバーID (参考情報) |
judge2 / judge3 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | AC * 1 |
other | AC * 25 WA * 10 |
ソースコード
#include <iostream> #include <vector> #include <algorithm> #include <map> // Not strictly needed, but useful for some approaches // Using long long for N, A, and costs/coefficients using ll = long long; const ll MOD = 998244353; const ll INF_COST = 4e18; // A very large number for cost comparison struct Func { ll const_term; // C0 in C0 + C1*K ll k_term; // C1 in C0 + C1*K (this will be negative, e.g. MOD-Am or MOD-1) ll min_K_domain; ll max_K_domain; }; // Modular inverse for division by 2 ll inv2 = (MOD + 1) / 2; // Sum of K from L to R, modulo MOD ll sum_arithmetic_progression_K(ll L, ll R) { if (L > R) return 0; ll n = (R - L + 1); // Number of terms // Sum = (L+R) * n / 2 ll term1 = (L % MOD + R % MOD) % MOD; ll term2 = (n % MOD * inv2 % MOD); return (term1 * term2) % MOD; } ll solve() { ll N, A; std::cin >> N >> A; if (A == 1) { // f(K) = N - K // Sum (N-K) for K=1 to N = N*N - N(N+1)/2 = N(N-1)/2 ll N_mod = N % MOD; ll num = (N_mod * ((N_mod - 1 + MOD) % MOD)) % MOD; ll ans = (num * inv2) % MOD; return ans; } std::vector<Func> funcs; // Base case: m=0 (only additions) // f(K) = N - K. const_term = N, k_term = -1 funcs.push_back({N % MOD, (MOD - 1), 1, N}); ll current_A_power = 1; // Stores A^m for (int m = 1; m <= 62; ++m) { // m up to ~60-62 for N=10^18, A=2 ll prev_A_power = current_A_power; if (A > INF_COST / prev_A_power) { // Check for overflow before multiplication current_A_power = INF_COST; // Mark as too large } else { current_A_power = prev_A_power * A; } if (current_A_power > N || current_A_power == 0) { // current_A_power = 0 if overflow occurred weirdly break; } // Strategy 1 (c0=0): cost = m + N - K * A^m // const_term = m+N, k_term = -A^m ll max_K_strat1 = N / current_A_power; if (max_K_strat1 >= 1) { ll const_t = (m + N % MOD) % MOD; ll k_t = (MOD - (current_A_power % MOD)) % MOD; funcs.push_back({const_t, k_t, 1, max_K_strat1}); } // Strategy 2 (c0>0 optimal): cost = (X_m-K) + m + N - X_m*A^m // X_m = floor(N/A^m) // const_term = X_m + m + N - X_m*A^m, k_term = -1 ll X_m = N / current_A_power; if (X_m >= 1) { // K must be <= X_m, and K >= 1 ll Xm_mod = X_m % MOD; ll Am_mod = current_A_power % MOD; ll term_Xm_Am = (Xm_mod * Am_mod) % MOD; ll const_t = (Xm_mod + m % MOD + N % MOD - term_Xm_Am + MOD) % MOD; ll k_t = (MOD - 1); funcs.push_back({const_t, k_t, 1, X_m}); } } std::vector<ll> points; points.push_back(1); points.push_back(N + 1); // Sentinel for last interval N to N for (size_t i = 0; i < funcs.size(); ++i) { for (size_t j = i + 1; j < funcs.size(); ++j) { // Intersection K of C_i + M_i*K = C_j + M_j*K // (M_i - M_j)K = C_j - C_i // K = (C_j - C_i) / (M_i - M_j) ll Ci = funcs[i].const_term; ll Mi = funcs[i].k_term; // This is already (MOD - val) form for negative ll Cj = funcs[j].const_term; ll Mj = funcs[j].k_term; ll num = (Cj - Ci + MOD) % MOD; ll den = (Mi - Mj + MOD) % MOD; if (den == 0) continue; // Parallel lines or same line // We need K_intersect. For K = num * den_inv (mod MOD). This is not for real K. // We need real K where functions cross. // Original coefficients: M_i_orig = -funcs[i].<actual_coeff_of_K> // Let's use original expressions: const_val_i + k_coeff_i * K (k_coeff_i is negative) // (k_coeff_i - k_coeff_j) * K = const_val_j - const_val_i // K_exact = (const_val_j - const_val_i) / (k_coeff_i - k_coeff_j) // Need to use non-modular values for k_coeff for division // Reconstruct true coefficients for intersection calculation // This part is tricky with modular arithmetic for coefficients. // It's easier to evaluate at K_mid for each interval. // The number of pieces is O(M), so few critical points. // For simplicity, the critical points are usually just domain boundaries. // Max_K_domain for each function: points.push_back(funcs[i].max_K_domain + 1); points.push_back(funcs[j].max_K_domain + 1); } } std::sort(points.begin(), points.end()); points.erase(std::unique(points.begin(), points.end()), points.end()); ll total_f_sum = 0; for (ll K_start : points) { if (K_start > N) continue; auto it = std::upper_bound(points.begin(), points.end(), K_start); ll K_end = N; if (it != points.end()) { K_end = std::min(N, (*it) - 1); } if (K_start > K_end) continue; ll K_mid = K_start; // Test point for the interval ll min_true_cost_at_K_mid = N - K_mid; // From m=0 case Func best_func = funcs[0]; // Default to m=0 function ll temp_A_power = 1; for (int m = 1; m <= 62; ++m) { ll prev_A_power = temp_A_power; if (A > INF_COST / prev_A_power) { temp_A_power = INF_COST; } else { temp_A_power = prev_A_power * A; } if (temp_A_power > N || temp_A_power == 0) break; // Strategy 1 (c0=0) if (K_mid * temp_A_power <= N) { // Valid K_mid for this strategy ll cost1 = m + N - K_mid * temp_A_power; if (cost1 < min_true_cost_at_K_mid) { min_true_cost_at_K_mid = cost1; ll const_t = (m % MOD + N % MOD) % MOD; ll k_t = (MOD - (temp_A_power % MOD)) % MOD; best_func = {const_t, k_t, 1, N / temp_A_power}; } } // Strategy 2 (c0>0 optimal) ll X_m = N / temp_A_power; if (K_mid <= X_m) { // Valid K_mid for this strategy ll cost2 = (X_m - K_mid) + m + N - X_m * temp_A_power; if (cost2 < min_true_cost_at_K_mid) { min_true_cost_at_K_mid = cost2; ll Xm_mod = X_m % MOD; ll Am_mod = temp_A_power % MOD; ll term_Xm_Am = (Xm_mod * Am_mod) % MOD; ll const_t = (Xm_mod + m % MOD + N % MOD - term_Xm_Am + MOD) % MOD; ll k_t = (MOD - 1); best_func = {const_t, k_t, 1, X_m}; } } } // Ensure the chosen best_func is valid for the entire [K_start, K_end] range. // The K_mid test should ideally be robust. The lower envelope of lines means one function is optimal over an interval. // Max_K_domain check is important. The interval [K_start, K_end] must be within best_func's domain. ll actual_K_end = std::min(K_end, best_func.max_K_domain); if (K_start > actual_K_end) continue; ll C0 = best_func.const_term; ll C1 = best_func.k_term; // This is already (MOD - val) form ll num_terms = (actual_K_end - K_start + 1); ll sum_C0 = (num_terms % MOD * C0) % MOD; ll sum_K = sum_arithmetic_progression_K(K_start, actual_K_end); ll sum_C1K = (C1 * sum_K) % MOD; ll current_interval_sum = (sum_C0 + sum_C1K + MOD) % MOD; total_f_sum = (total_f_sum + current_interval_sum) % MOD; } return total_f_sum; } int main() { std::ios_base::sync_with_stdio(false); std::cin.tie(NULL); int T; std::cin >> T; while (T--) { std::cout << solve() << "\n"; } return 0; }