結果

問題 No.3118 Increment or Multiply
ユーザー Naru820
提出日時 2025-05-21 16:01:07
言語 C++23
(gcc 13.3.0 + boost 1.87.0)
結果
WA  
実行時間 -
コード長 7,963 bytes
コンパイル時間 1,823 ms
コンパイル使用メモリ 113,804 KB
実行使用メモリ 7,844 KB
最終ジャッジ日時 2025-05-21 16:01:22
合計ジャッジ時間 13,710 ms
ジャッジサーバーID
(参考情報)
judge2 / judge3
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 1
other AC * 25 WA * 10
権限があれば一括ダウンロードができます

ソースコード

diff #

#include <iostream>
#include <vector>
#include <algorithm>
#include <map> // Not strictly needed, but useful for some approaches

// Using long long for N, A, and costs/coefficients
using ll = long long;

const ll MOD = 998244353;
const ll INF_COST = 4e18; // A very large number for cost comparison

struct Func {
    ll const_term; // C0 in C0 + C1*K
    ll k_term;     // C1 in C0 + C1*K (this will be negative, e.g. MOD-Am or MOD-1)
    ll min_K_domain;
    ll max_K_domain;
};

// Modular inverse for division by 2
ll inv2 = (MOD + 1) / 2;

// Sum of K from L to R, modulo MOD
ll sum_arithmetic_progression_K(ll L, ll R) {
    if (L > R) return 0;
    ll n = (R - L + 1); // Number of terms
    // Sum = (L+R) * n / 2
    ll term1 = (L % MOD + R % MOD) % MOD;
    ll term2 = (n % MOD * inv2 % MOD);
    return (term1 * term2) % MOD;
}


ll solve() {
    ll N, A;
    std::cin >> N >> A;

    if (A == 1) {
        // f(K) = N - K
        // Sum (N-K) for K=1 to N = N*N - N(N+1)/2 = N(N-1)/2
        ll N_mod = N % MOD;
        ll num = (N_mod * ((N_mod - 1 + MOD) % MOD)) % MOD;
        ll ans = (num * inv2) % MOD;
        return ans;
    }

    std::vector<Func> funcs;

    // Base case: m=0 (only additions)
    // f(K) = N - K. const_term = N, k_term = -1
    funcs.push_back({N % MOD, (MOD - 1), 1, N});

    ll current_A_power = 1; // Stores A^m
    for (int m = 1; m <= 62; ++m) { // m up to ~60-62 for N=10^18, A=2
        ll prev_A_power = current_A_power;
        if (A > INF_COST / prev_A_power) { // Check for overflow before multiplication
            current_A_power = INF_COST; // Mark as too large
        } else {
            current_A_power = prev_A_power * A;
        }

        if (current_A_power > N || current_A_power == 0) { // current_A_power = 0 if overflow occurred weirdly
            break;
        }

        // Strategy 1 (c0=0): cost = m + N - K * A^m
        // const_term = m+N, k_term = -A^m
        ll max_K_strat1 = N / current_A_power;
        if (max_K_strat1 >= 1) {
            ll const_t = (m + N % MOD) % MOD;
            ll k_t = (MOD - (current_A_power % MOD)) % MOD;
            funcs.push_back({const_t, k_t, 1, max_K_strat1});
        }

        // Strategy 2 (c0>0 optimal): cost = (X_m-K) + m + N - X_m*A^m
        // X_m = floor(N/A^m)
        // const_term = X_m + m + N - X_m*A^m, k_term = -1
        ll X_m = N / current_A_power;
        if (X_m >= 1) { // K must be <= X_m, and K >= 1
            ll Xm_mod = X_m % MOD;
            ll Am_mod = current_A_power % MOD;
            ll term_Xm_Am = (Xm_mod * Am_mod) % MOD;
            
            ll const_t = (Xm_mod + m % MOD + N % MOD - term_Xm_Am + MOD) % MOD;
            ll k_t = (MOD - 1);
            funcs.push_back({const_t, k_t, 1, X_m});
        }
    }

    std::vector<ll> points;
    points.push_back(1);
    points.push_back(N + 1); // Sentinel for last interval N to N

    for (size_t i = 0; i < funcs.size(); ++i) {
        for (size_t j = i + 1; j < funcs.size(); ++j) {
            // Intersection K of C_i + M_i*K = C_j + M_j*K
            // (M_i - M_j)K = C_j - C_i
            // K = (C_j - C_i) / (M_i - M_j)
            ll Ci = funcs[i].const_term;
            ll Mi = funcs[i].k_term; // This is already (MOD - val) form for negative
            ll Cj = funcs[j].const_term;
            ll Mj = funcs[j].k_term;

            ll num = (Cj - Ci + MOD) % MOD;
            ll den = (Mi - Mj + MOD) % MOD;

            if (den == 0) continue; // Parallel lines or same line

            // We need K_intersect. For K = num * den_inv (mod MOD). This is not for real K.
            // We need real K where functions cross.
            // Original coefficients: M_i_orig = -funcs[i].<actual_coeff_of_K>
            // Let's use original expressions: const_val_i + k_coeff_i * K (k_coeff_i is negative)
            // (k_coeff_i - k_coeff_j) * K = const_val_j - const_val_i
            // K_exact = (const_val_j - const_val_i) / (k_coeff_i - k_coeff_j)
            // Need to use non-modular values for k_coeff for division
            
            // Reconstruct true coefficients for intersection calculation
            // This part is tricky with modular arithmetic for coefficients.
            // It's easier to evaluate at K_mid for each interval.
            // The number of pieces is O(M), so few critical points.
            // For simplicity, the critical points are usually just domain boundaries.
            // Max_K_domain for each function:
            points.push_back(funcs[i].max_K_domain + 1);
            points.push_back(funcs[j].max_K_domain + 1);
        }
    }
    
    std::sort(points.begin(), points.end());
    points.erase(std::unique(points.begin(), points.end()), points.end());

    ll total_f_sum = 0;

    for (ll K_start : points) {
        if (K_start > N) continue;
        
        auto it = std::upper_bound(points.begin(), points.end(), K_start);
        ll K_end = N;
        if (it != points.end()) {
            K_end = std::min(N, (*it) - 1);
        }
        
        if (K_start > K_end) continue;

        ll K_mid = K_start; // Test point for the interval

        ll min_true_cost_at_K_mid = N - K_mid; // From m=0 case
        Func best_func = funcs[0]; // Default to m=0 function

        ll temp_A_power = 1;
        for (int m = 1; m <= 62; ++m) {
            ll prev_A_power = temp_A_power;
             if (A > INF_COST / prev_A_power) { 
                temp_A_power = INF_COST;
            } else {
                temp_A_power = prev_A_power * A;
            }
            if (temp_A_power > N || temp_A_power == 0) break;

            // Strategy 1 (c0=0)
            if (K_mid * temp_A_power <= N) { // Valid K_mid for this strategy
                ll cost1 = m + N - K_mid * temp_A_power;
                if (cost1 < min_true_cost_at_K_mid) {
                    min_true_cost_at_K_mid = cost1;
                    ll const_t = (m % MOD + N % MOD) % MOD;
                    ll k_t = (MOD - (temp_A_power % MOD)) % MOD;
                    best_func = {const_t, k_t, 1, N / temp_A_power};
                }
            }

            // Strategy 2 (c0>0 optimal)
            ll X_m = N / temp_A_power;
            if (K_mid <= X_m) { // Valid K_mid for this strategy
                ll cost2 = (X_m - K_mid) + m + N - X_m * temp_A_power;
                if (cost2 < min_true_cost_at_K_mid) {
                    min_true_cost_at_K_mid = cost2;
                    ll Xm_mod = X_m % MOD;
                    ll Am_mod = temp_A_power % MOD;
                    ll term_Xm_Am = (Xm_mod * Am_mod) % MOD;
                    ll const_t = (Xm_mod + m % MOD + N % MOD - term_Xm_Am + MOD) % MOD;
                    ll k_t = (MOD - 1);
                    best_func = {const_t, k_t, 1, X_m};
                }
            }
        }
        
        // Ensure the chosen best_func is valid for the entire [K_start, K_end] range.
        // The K_mid test should ideally be robust. The lower envelope of lines means one function is optimal over an interval.
        // Max_K_domain check is important. The interval [K_start, K_end] must be within best_func's domain.
        ll actual_K_end = std::min(K_end, best_func.max_K_domain);
        if (K_start > actual_K_end) continue;


        ll C0 = best_func.const_term;
        ll C1 = best_func.k_term; // This is already (MOD - val) form

        ll num_terms = (actual_K_end - K_start + 1);
        ll sum_C0 = (num_terms % MOD * C0) % MOD;
        
        ll sum_K = sum_arithmetic_progression_K(K_start, actual_K_end);
        ll sum_C1K = (C1 * sum_K) % MOD;
        
        ll current_interval_sum = (sum_C0 + sum_C1K + MOD) % MOD;
        total_f_sum = (total_f_sum + current_interval_sum) % MOD;
    }
    return total_f_sum;
}

int main() {
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(NULL);
    int T;
    std::cin >> T;
    while (T--) {
        std::cout << solve() << "\n";
    }
    return 0;
}
0