結果
問題 |
No.3118 Increment or Multiply
|
ユーザー |
|
提出日時 | 2025-08-09 08:33:30 |
言語 | C++23 (gcc 13.3.0 + boost 1.87.0) |
結果 |
WA
|
実行時間 | - |
コード長 | 7,446 bytes |
コンパイル時間 | 3,153 ms |
コンパイル使用メモリ | 297,700 KB |
実行使用メモリ | 7,720 KB |
最終ジャッジ日時 | 2025-08-09 08:33:36 |
合計ジャッジ時間 | 5,935 ms |
ジャッジサーバーID (参考情報) |
judge5 / judge1 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | WA * 1 |
other | AC * 5 WA * 30 |
ソースコード
// by GPT-5 やっぱこの問題AI苦手みたいなんだけどなんでなんだろうな~ #include <bits/stdc++.h> using namespace std; using int64 = long long; const int64 MOD = 998244353; // safe add modulo inline int64 addmod(int64 a, int64 b) { a += b; if (a >= MOD) a -= MOD; return a; } inline int64 mulmod(int64 a, int64 b) { __int128 t = (__int128)a * b; return (int64)(t % MOD); } // compute sum of arithmetic progression y = a..b: sum y = (a+b)*cnt/2 // we compute modulo MOD; cnt fits in 128-bit but may be large; do operations with __int128 then mod. int64 sum_range_mod(__int128 a, __int128 b) { __int128 cnt = b - a + 1; __int128 s = (a + b) * cnt / 2; s %= MOD; if (s < 0) s += MOD; return (int64)s; } // main logic: // We maintain map<start, end> of unvisited K's (initially {1..N}). // Priority queue of (dist, node) starting with (0, N). // When popping (d, u): // compute orig_l = u - (u % A). If orig_l == 0 -> orig_l = 1 (clamped). // Special handling: if u % A == 0, also allow subtracting some numbers before division // (we include a small window to the left) — implementation detail to avoid missing neighbors. // For intersection [l..u] with unvisited intervals, we compute sum of distances for that subinterval // using arithmetic formulas and remove it from the map. // If orig_l >= 1 and orig_l % A == 0, push (d + (u - orig_l) + 1, orig_l / A). // // Note: This implementation follows a standard competitive-programming approach using Dijkstra // + interval-compression. It is written for performance and uses 128-bit where needed. int main(){ ios::sync_with_stdio(false); cin.tie(nullptr); int T; if(!(cin >> T)) return 0; while(T--){ long long N, A; cin >> N >> A; if (A == 1){ // f(K) = N - K, sum = sum_{K=1..N} (N-K) = N*(N-1)/2 __int128 val = (__int128)N * (N - 1) / 2; long long ans = (long long)(val % MOD); cout << ans % MOD << '\n'; continue; } // intervals: map start -> end of unvisited K std::map<long long, long long> intervals; intervals[1] = N; // min-heap by distance using P = pair<unsigned long long, long long>; // (dist, node) priority_queue<P, vector<P>, greater<P>> pq; pq.push({0ULL, N}); __int128 total = 0; // accumulate in 128-bit then mod at end // helper lambda to remove intersection [a..b] from intervals and add its contribution auto remove_and_add = [&](long long a, long long b, unsigned long long d){ // iterate intervals that intersect [a..b] -- in this function we assume exact intersection is single [a..b] // But actual removal is handled in the main loop. (void)a; (void)b; (void)d; }; while(!pq.empty() && !intervals.empty()){ auto [d_u, u] = pq.top(); pq.pop(); if (u < 1) continue; // orig_l is the nearest (<=u) number divisible by A long long r = u % A; long long orig_l = u - r; // orig_l % A == 0 (could be 0) long long l = orig_l; if (l < 1) l = 1; // Special adjustment: // If u is divisible by A (r==0), we still want to be able to reach u-1, u-2, ... // The contiguous block we should try to extract starts from max(1, u - (A - 1)) to u in that case. // This heuristic ensures we don't miss immediate predecessors when u is a multiple of A. if (r == 0) { long long cand = u - (A - 1); if (cand > 1) { // choose the maximum between current l and cand if (cand > l) l = cand; } else { l = 1; } } // Now we will remove all unvisited parts within [l..u], summing distances. // Find first interval with start <= u auto it = intervals.upper_bound(u); if (it != intervals.begin()) --it; else { /* it at begin */ } // move it to first interval that may intersect [l..u] while(it != intervals.end() && it->second < l){ ++it; } // iterate and remove intersections vector<pair<long long,long long>> to_add; // for splitting left/right pieces vector<map<long long,long long>::iterator> to_erase; while(it != intervals.end() && it->first <= u){ long long a0 = it->first; long long b0 = it->second; if (b0 < l) { ++it; continue; } long long a = max(a0, l); long long b = min(b0, u); if (a <= b){ // sum distances for y in [a..b] is sum_{y=a..b} (d_u + (u - y)) // = cnt * d_u + cnt * u - sum_{y=a..b} y __int128 cnt = (__int128)(b - a + 1); __int128 sum_y = (__int128)(a + b) * cnt / 2; __int128 part = (__int128)cnt * (__int128)d_u + (__int128)cnt * (__int128)u - sum_y; total += part; // Now adjust interval it: may need to keep left part [a0..a-1] and right part [b+1..b0] if (a0 < a && b < b0){ // split into two intervals: [a0, a-1] and [b+1, b0] long long left_s = a0, left_e = a-1; long long right_s = b+1, right_e = b0; // record to replace it with left and right to_add.emplace_back(left_s, left_e); to_add.emplace_back(right_s, right_e); to_erase.push_back(it); ++it; } else if (a0 < a){ // keep left part only long long left_s = a0, left_e = a-1; to_add.emplace_back(left_s, left_e); to_erase.push_back(it); ++it; } else if (b < b0){ // keep right part only long long right_s = b+1, right_e = b0; to_add.emplace_back(right_s, right_e); to_erase.push_back(it); ++it; } else { // entire interval removed to_erase.push_back(it); ++it; } } else { ++it; } } // apply erases and adds for(auto &it2 : to_erase) intervals.erase(it2); for(auto &p : to_add) intervals[p.first] = p.second; // after removing [l..u], attempt to push division result if orig_l >= 1 and orig_l % A == 0 if (orig_l >= 1 && (orig_l % A == 0)){ long long v = orig_l / A; unsigned long long nd = d_u + (unsigned long long)(u - orig_l) + 1ULL; // push only if v >= 1 if (v >= 1){ pq.push({nd, v}); } } } // output total % MOD long long ans = (long long)(total % MOD); if (ans < 0) ans += MOD; cout << ans << '\n'; } return 0; }