結果
| 問題 |
No.3118 Increment or Multiply
|
| コンテスト | |
| ユーザー |
|
| 提出日時 | 2025-08-09 08:33:30 |
| 言語 | C++23 (gcc 13.3.0 + boost 1.87.0) |
| 結果 |
WA
|
| 実行時間 | - |
| コード長 | 7,446 bytes |
| コンパイル時間 | 3,153 ms |
| コンパイル使用メモリ | 297,700 KB |
| 実行使用メモリ | 7,720 KB |
| 最終ジャッジ日時 | 2025-08-09 08:33:36 |
| 合計ジャッジ時間 | 5,935 ms |
|
ジャッジサーバーID (参考情報) |
judge5 / judge1 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | WA * 1 |
| other | AC * 5 WA * 30 |
ソースコード
// by GPT-5 やっぱこの問題AI苦手みたいなんだけどなんでなんだろうな~
#include <bits/stdc++.h>
using namespace std;
using int64 = long long;
const int64 MOD = 998244353;
// safe add modulo
inline int64 addmod(int64 a, int64 b) {
a += b;
if (a >= MOD) a -= MOD;
return a;
}
inline int64 mulmod(int64 a, int64 b) {
__int128 t = (__int128)a * b;
return (int64)(t % MOD);
}
// compute sum of arithmetic progression y = a..b: sum y = (a+b)*cnt/2
// we compute modulo MOD; cnt fits in 128-bit but may be large; do operations with __int128 then mod.
int64 sum_range_mod(__int128 a, __int128 b) {
__int128 cnt = b - a + 1;
__int128 s = (a + b) * cnt / 2;
s %= MOD;
if (s < 0) s += MOD;
return (int64)s;
}
// main logic:
// We maintain map<start, end> of unvisited K's (initially {1..N}).
// Priority queue of (dist, node) starting with (0, N).
// When popping (d, u):
// compute orig_l = u - (u % A). If orig_l == 0 -> orig_l = 1 (clamped).
// Special handling: if u % A == 0, also allow subtracting some numbers before division
// (we include a small window to the left) — implementation detail to avoid missing neighbors.
// For intersection [l..u] with unvisited intervals, we compute sum of distances for that subinterval
// using arithmetic formulas and remove it from the map.
// If orig_l >= 1 and orig_l % A == 0, push (d + (u - orig_l) + 1, orig_l / A).
//
// Note: This implementation follows a standard competitive-programming approach using Dijkstra
// + interval-compression. It is written for performance and uses 128-bit where needed.
int main(){
ios::sync_with_stdio(false);
cin.tie(nullptr);
int T;
if(!(cin >> T)) return 0;
while(T--){
long long N, A;
cin >> N >> A;
if (A == 1){
// f(K) = N - K, sum = sum_{K=1..N} (N-K) = N*(N-1)/2
__int128 val = (__int128)N * (N - 1) / 2;
long long ans = (long long)(val % MOD);
cout << ans % MOD << '\n';
continue;
}
// intervals: map start -> end of unvisited K
std::map<long long, long long> intervals;
intervals[1] = N;
// min-heap by distance
using P = pair<unsigned long long, long long>; // (dist, node)
priority_queue<P, vector<P>, greater<P>> pq;
pq.push({0ULL, N});
__int128 total = 0; // accumulate in 128-bit then mod at end
// helper lambda to remove intersection [a..b] from intervals and add its contribution
auto remove_and_add = [&](long long a, long long b, unsigned long long d){
// iterate intervals that intersect [a..b] -- in this function we assume exact intersection is single [a..b]
// But actual removal is handled in the main loop.
(void)a; (void)b; (void)d;
};
while(!pq.empty() && !intervals.empty()){
auto [d_u, u] = pq.top(); pq.pop();
if (u < 1) continue;
// orig_l is the nearest (<=u) number divisible by A
long long r = u % A;
long long orig_l = u - r; // orig_l % A == 0 (could be 0)
long long l = orig_l;
if (l < 1) l = 1;
// Special adjustment:
// If u is divisible by A (r==0), we still want to be able to reach u-1, u-2, ...
// The contiguous block we should try to extract starts from max(1, u - (A - 1)) to u in that case.
// This heuristic ensures we don't miss immediate predecessors when u is a multiple of A.
if (r == 0) {
long long cand = u - (A - 1);
if (cand > 1) {
// choose the maximum between current l and cand
if (cand > l) l = cand;
} else {
l = 1;
}
}
// Now we will remove all unvisited parts within [l..u], summing distances.
// Find first interval with start <= u
auto it = intervals.upper_bound(u);
if (it != intervals.begin()) --it; else { /* it at begin */ }
// move it to first interval that may intersect [l..u]
while(it != intervals.end() && it->second < l){
++it;
}
// iterate and remove intersections
vector<pair<long long,long long>> to_add; // for splitting left/right pieces
vector<map<long long,long long>::iterator> to_erase;
while(it != intervals.end() && it->first <= u){
long long a0 = it->first;
long long b0 = it->second;
if (b0 < l) { ++it; continue; }
long long a = max(a0, l);
long long b = min(b0, u);
if (a <= b){
// sum distances for y in [a..b] is sum_{y=a..b} (d_u + (u - y))
// = cnt * d_u + cnt * u - sum_{y=a..b} y
__int128 cnt = (__int128)(b - a + 1);
__int128 sum_y = (__int128)(a + b) * cnt / 2;
__int128 part = (__int128)cnt * (__int128)d_u + (__int128)cnt * (__int128)u - sum_y;
total += part;
// Now adjust interval it: may need to keep left part [a0..a-1] and right part [b+1..b0]
if (a0 < a && b < b0){
// split into two intervals: [a0, a-1] and [b+1, b0]
long long left_s = a0, left_e = a-1;
long long right_s = b+1, right_e = b0;
// record to replace it with left and right
to_add.emplace_back(left_s, left_e);
to_add.emplace_back(right_s, right_e);
to_erase.push_back(it);
++it;
} else if (a0 < a){
// keep left part only
long long left_s = a0, left_e = a-1;
to_add.emplace_back(left_s, left_e);
to_erase.push_back(it);
++it;
} else if (b < b0){
// keep right part only
long long right_s = b+1, right_e = b0;
to_add.emplace_back(right_s, right_e);
to_erase.push_back(it);
++it;
} else {
// entire interval removed
to_erase.push_back(it);
++it;
}
} else {
++it;
}
}
// apply erases and adds
for(auto &it2 : to_erase) intervals.erase(it2);
for(auto &p : to_add) intervals[p.first] = p.second;
// after removing [l..u], attempt to push division result if orig_l >= 1 and orig_l % A == 0
if (orig_l >= 1 && (orig_l % A == 0)){
long long v = orig_l / A;
unsigned long long nd = d_u + (unsigned long long)(u - orig_l) + 1ULL;
// push only if v >= 1
if (v >= 1){
pq.push({nd, v});
}
}
}
// output total % MOD
long long ans = (long long)(total % MOD);
if (ans < 0) ans += MOD;
cout << ans << '\n';
}
return 0;
}