結果
問題 |
No.3118 Increment or Multiply
|
ユーザー |
|
提出日時 | 2025-04-19 15:41:31 |
言語 | C++17 (gcc 13.3.0 + boost 1.87.0) |
結果 |
AC
|
実行時間 | 32 ms / 2,000 ms |
コード長 | 2,512 bytes |
コンパイル時間 | 2,374 ms |
コンパイル使用メモリ | 198,936 KB |
実行使用メモリ | 7,844 KB |
最終ジャッジ日時 | 2025-04-19 15:41:36 |
合計ジャッジ時間 | 4,508 ms |
ジャッジサーバーID (参考情報) |
judge1 / judge4 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | AC * 1 |
other | AC * 35 |
ソースコード
#include <bits/stdc++.h> using namespace std; typedef long long ll; const ll MOD = 998244353; // modint struct mint { ll v; mint(): v(0) {} mint(ll _v) { v = (_v % MOD + MOD) % MOD; } mint& operator+=(const mint& o) { v += o.v; if (v >= MOD) v -= MOD; return *this; } mint& operator-=(const mint& o) { v -= o.v; if (v < 0) v += MOD; return *this; } mint& operator*=(const mint& o) { v = (ll)((__int128)v * o.v % MOD); return *this; } friend mint operator+(mint a, const mint& b) { return a += b; } friend mint operator-(mint a, const mint& b) { return a -= b; } friend mint operator*(mint a, const mint& b) { return a *= b; } }; int main() { ios::sync_with_stdio(false); cin.tie(NULL); int T; cin >> T; while (T--) { ll N, A; cin >> N >> A; mint ans = 0; if (A == 1) { // f(K) = N - K, sum = N*(N-1)/2 mint mN = N; mint inv2 = (MOD + 1) / 2; ans = mN * mint(N - 1) * inv2; } else { // base-A digits of N (LSB first) vector<ll> digits; ll tmp = N; while (tmp > 0) { digits.push_back(tmp % A); tmp /= A; } int D = digits.size(); // prefix sums of digits: S[t] = sum_{j=0..t-1} digits[j] vector<ll> S(D + 1, 0); for (int i = 1; i <= D; ++i) { S[i] = S[i - 1] + digits[i - 1]; } mint inv2 = (MOD + 1) / 2; // compute powers of A up to N safely vector<unsigned long long> pw; pw.reserve(D); pw.push_back(1); for (int i = 1; i <= D; ++i) { __uint128_t nxt = ( __uint128_t )pw.back() * A; if (nxt > ( __uint128_t )N) break; pw.push_back((unsigned long long)nxt); } int Tpw = pw.size(); // sum over segments for (int t = 0; t < Tpw; ++t) { ll R = N / pw[t]; ll L = (t + 1 < Tpw ? N / pw[t + 1] : 0); ll cnt = R - L; if (cnt <= 0) continue; // sum f = cnt*(t + S[t]) + cnt*(cnt-1)/2 mint mcnt = cnt; mint base = mcnt * mint(t + S[t]); mint tail = mcnt * mint(cnt - 1) * inv2; ans += base + tail; } } cout << ans.v; if (T) cout << '\n'; } return 0; }