結果

問題 No.3118 Increment or Multiply
ユーザー Naru820
提出日時 2025-08-09 08:33:30
言語 C++23
(gcc 13.3.0 + boost 1.87.0)
結果
WA  
実行時間 -
コード長 7,446 bytes
コンパイル時間 3,153 ms
コンパイル使用メモリ 297,700 KB
実行使用メモリ 7,720 KB
最終ジャッジ日時 2025-08-09 08:33:36
合計ジャッジ時間 5,935 ms
ジャッジサーバーID
(参考情報)
judge5 / judge1
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample WA * 1
other AC * 5 WA * 30
権限があれば一括ダウンロードができます

ソースコード

diff #

// by GPT-5 やっぱこの問題AI苦手みたいなんだけどなんでなんだろうな~
#include <bits/stdc++.h>
using namespace std;
using int64 = long long;
const int64 MOD = 998244353;

// safe add modulo
inline int64 addmod(int64 a, int64 b) {
    a += b;
    if (a >= MOD) a -= MOD;
    return a;
}
inline int64 mulmod(int64 a, int64 b) {
    __int128 t = (__int128)a * b;
    return (int64)(t % MOD);
}

// compute sum of arithmetic progression y = a..b: sum y = (a+b)*cnt/2
// we compute modulo MOD; cnt fits in 128-bit but may be large; do operations with __int128 then mod.
int64 sum_range_mod(__int128 a, __int128 b) {
    __int128 cnt = b - a + 1;
    __int128 s = (a + b) * cnt / 2;
    s %= MOD;
    if (s < 0) s += MOD;
    return (int64)s;
}

// main logic:
// We maintain map<start, end> of unvisited K's (initially {1..N}).
// Priority queue of (dist, node) starting with (0, N).
// When popping (d, u):
//   compute orig_l = u - (u % A). If orig_l == 0 -> orig_l = 1 (clamped).
//   Special handling: if u % A == 0, also allow subtracting some numbers before division
//      (we include a small window to the left) — implementation detail to avoid missing neighbors.
//   For intersection [l..u] with unvisited intervals, we compute sum of distances for that subinterval
//      using arithmetic formulas and remove it from the map.
//   If orig_l >= 1 and orig_l % A == 0, push (d + (u - orig_l) + 1, orig_l / A).
//
// Note: This implementation follows a standard competitive-programming approach using Dijkstra
// + interval-compression. It is written for performance and uses 128-bit where needed.

int main(){
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    int T;
    if(!(cin >> T)) return 0;
    while(T--){
        long long N, A;
        cin >> N >> A;
        if (A == 1){
            // f(K) = N - K, sum = sum_{K=1..N} (N-K) = N*(N-1)/2
            __int128 val = (__int128)N * (N - 1) / 2;
            long long ans = (long long)(val % MOD);
            cout << ans % MOD << '\n';
            continue;
        }

        // intervals: map start -> end of unvisited K
        std::map<long long, long long> intervals;
        intervals[1] = N;

        // min-heap by distance
        using P = pair<unsigned long long, long long>; // (dist, node)
        priority_queue<P, vector<P>, greater<P>> pq;
        pq.push({0ULL, N});

        __int128 total = 0; // accumulate in 128-bit then mod at end

        // helper lambda to remove intersection [a..b] from intervals and add its contribution
        auto remove_and_add = [&](long long a, long long b, unsigned long long d){
            // iterate intervals that intersect [a..b] -- in this function we assume exact intersection is single [a..b]
            // But actual removal is handled in the main loop.
            (void)a; (void)b; (void)d;
        };

        while(!pq.empty() && !intervals.empty()){
            auto [d_u, u] = pq.top(); pq.pop();
            if (u < 1) continue;

            // orig_l is the nearest (<=u) number divisible by A
            long long r = u % A;
            long long orig_l = u - r; // orig_l % A == 0 (could be 0)
            long long l = orig_l;
            if (l < 1) l = 1;

            // Special adjustment:
            // If u is divisible by A (r==0), we still want to be able to reach u-1, u-2, ...
            // The contiguous block we should try to extract starts from max(1, u - (A - 1)) to u in that case.
            // This heuristic ensures we don't miss immediate predecessors when u is a multiple of A.
            if (r == 0) {
                long long cand = u - (A - 1);
                if (cand > 1) {
                    // choose the maximum between current l and cand
                    if (cand > l) l = cand;
                } else {
                    l = 1;
                }
            }

            // Now we will remove all unvisited parts within [l..u], summing distances.
            // Find first interval with start <= u
            auto it = intervals.upper_bound(u);
            if (it != intervals.begin()) --it; else { /* it at begin */ }
            // move it to first interval that may intersect [l..u]
            while(it != intervals.end() && it->second < l){
                ++it;
            }
            // iterate and remove intersections
            vector<pair<long long,long long>> to_add; // for splitting left/right pieces
            vector<map<long long,long long>::iterator> to_erase;
            while(it != intervals.end() && it->first <= u){
                long long a0 = it->first;
                long long b0 = it->second;
                if (b0 < l) { ++it; continue; }
                long long a = max(a0, l);
                long long b = min(b0, u);
                if (a <= b){
                    // sum distances for y in [a..b] is sum_{y=a..b} (d_u + (u - y))
                    // = cnt * d_u + cnt * u - sum_{y=a..b} y
                    __int128 cnt = (__int128)(b - a + 1);
                    __int128 sum_y = (__int128)(a + b) * cnt / 2;
                    __int128 part = (__int128)cnt * (__int128)d_u + (__int128)cnt * (__int128)u - sum_y;
                    total += part;
                    // Now adjust interval it: may need to keep left part [a0..a-1] and right part [b+1..b0]
                    if (a0 < a && b < b0){
                        // split into two intervals: [a0, a-1] and [b+1, b0]
                        long long left_s = a0, left_e = a-1;
                        long long right_s = b+1, right_e = b0;
                        // record to replace it with left and right
                        to_add.emplace_back(left_s, left_e);
                        to_add.emplace_back(right_s, right_e);
                        to_erase.push_back(it);
                        ++it;
                    } else if (a0 < a){
                        // keep left part only
                        long long left_s = a0, left_e = a-1;
                        to_add.emplace_back(left_s, left_e);
                        to_erase.push_back(it);
                        ++it;
                    } else if (b < b0){
                        // keep right part only
                        long long right_s = b+1, right_e = b0;
                        to_add.emplace_back(right_s, right_e);
                        to_erase.push_back(it);
                        ++it;
                    } else {
                        // entire interval removed
                        to_erase.push_back(it);
                        ++it;
                    }
                } else {
                    ++it;
                }
            }
            // apply erases and adds
            for(auto &it2 : to_erase) intervals.erase(it2);
            for(auto &p : to_add) intervals[p.first] = p.second;

            // after removing [l..u], attempt to push division result if orig_l >= 1 and orig_l % A == 0
            if (orig_l >= 1 && (orig_l % A == 0)){
                long long v = orig_l / A;
                unsigned long long nd = d_u + (unsigned long long)(u - orig_l) + 1ULL;
                // push only if v >= 1
                if (v >= 1){
                    pq.push({nd, v});
                }
            }
        }

        // output total % MOD
        long long ans = (long long)(total % MOD);
        if (ans < 0) ans += MOD;
        cout << ans << '\n';
    }
    return 0;
}
0