結果

問題 No.3118 Increment or Multiply
ユーザー shibh308
提出日時 2025-04-19 15:41:31
言語 C++17
(gcc 13.3.0 + boost 1.87.0)
結果
AC  
実行時間 32 ms / 2,000 ms
コード長 2,512 bytes
コンパイル時間 2,374 ms
コンパイル使用メモリ 198,936 KB
実行使用メモリ 7,844 KB
最終ジャッジ日時 2025-04-19 15:41:36
合計ジャッジ時間 4,508 ms
ジャッジサーバーID
(参考情報)
judge1 / judge4
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 1
other AC * 35
権限があれば一括ダウンロードができます

ソースコード

diff #

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const ll MOD = 998244353;

// modint
struct mint {
    ll v;
    mint(): v(0) {}
    mint(ll _v) { v = (_v % MOD + MOD) % MOD; }
    mint& operator+=(const mint& o) { v += o.v; if (v >= MOD) v -= MOD; return *this; }
    mint& operator-=(const mint& o) { v -= o.v; if (v < 0) v += MOD; return *this; }
    mint& operator*=(const mint& o) { v = (ll)((__int128)v * o.v % MOD); return *this; }
    friend mint operator+(mint a, const mint& b) { return a += b; }
    friend mint operator-(mint a, const mint& b) { return a -= b; }
    friend mint operator*(mint a, const mint& b) { return a *= b; }
};

int main() {
    ios::sync_with_stdio(false);
    cin.tie(NULL);

    int T;
    cin >> T;
    while (T--) {
        ll N, A;
        cin >> N >> A;
        mint ans = 0;
        if (A == 1) {
            // f(K) = N - K, sum = N*(N-1)/2
            mint mN = N;
            mint inv2 = (MOD + 1) / 2;
            ans = mN * mint(N - 1) * inv2;
        } else {
            // base-A digits of N (LSB first)
            vector<ll> digits;
            ll tmp = N;
            while (tmp > 0) {
                digits.push_back(tmp % A);
                tmp /= A;
            }
            int D = digits.size();
            // prefix sums of digits: S[t] = sum_{j=0..t-1} digits[j]
            vector<ll> S(D + 1, 0);
            for (int i = 1; i <= D; ++i) {
                S[i] = S[i - 1] + digits[i - 1];
            }
            mint inv2 = (MOD + 1) / 2;
            // compute powers of A up to N safely
            vector<unsigned long long> pw;
            pw.reserve(D);
            pw.push_back(1);
            for (int i = 1; i <= D; ++i) {
                __uint128_t nxt = ( __uint128_t )pw.back() * A;
                if (nxt > ( __uint128_t )N) break;
                pw.push_back((unsigned long long)nxt);
            }
            int Tpw = pw.size();
            // sum over segments
            for (int t = 0; t < Tpw; ++t) {
                ll R = N / pw[t];
                ll L = (t + 1 < Tpw ? N / pw[t + 1] : 0);
                ll cnt = R - L;
                if (cnt <= 0) continue;
                // sum f = cnt*(t + S[t]) + cnt*(cnt-1)/2
                mint mcnt = cnt;
                mint base = mcnt * mint(t + S[t]);
                mint tail = mcnt * mint(cnt - 1) * inv2;
                ans += base + tail;
            }
        }
        cout << ans.v;
        if (T) cout << '\n';
    }
    return 0;
}
0