結果

問題 No.3118 Increment or Multiply
ユーザー shibh308
提出日時 2025-04-19 15:35:11
言語 C++17
(gcc 13.3.0 + boost 1.87.0)
結果
WA  
実行時間 -
コード長 2,258 bytes
コンパイル時間 2,108 ms
コンパイル使用メモリ 197,256 KB
実行使用メモリ 7,848 KB
最終ジャッジ日時 2025-04-19 15:35:15
合計ジャッジ時間 3,940 ms
ジャッジサーバーID
(参考情報)
judge3 / judge1
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 1
other AC * 5 WA * 30
権限があれば一括ダウンロードができます

ソースコード

diff #

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const ll MOD = 998244353;

struct mint {
    ll v;
    mint(): v(0) {}
    mint(ll _v) { v = (_v % MOD + MOD) % MOD; }
    mint& operator+=(const mint& o) { v = v + o.v; if (v >= MOD) v -= MOD; return *this; }
    mint& operator-=(const mint& o) { v = v - o.v; if (v < 0) v += MOD; return *this; }
    mint& operator*=(const mint& o) { v = (ll)( ( __int128 )v * o.v % MOD ); return *this; }
    friend mint operator+(mint a, const mint& b) { return a += b; }
    friend mint operator-(mint a, const mint& b) { return a -= b; }
    friend mint operator*(mint a, const mint& b) { return a *= b; }
};

int main(){
    ios::sync_with_stdio(false);
    cin.tie(NULL);

    int T;
    if(!(cin >> T)) return 0;
    while(T--){
        ll N, A;
        cin >> N >> A;
        mint ans = 0;
        if(A == 1){
            // sum_{K=1..N} (N-K) = N*(N-1)/2
            mint mN = N;
            ans = mN * mint(N-1) * mint((MOD+1)/2);
        } else {
            // precompute powers of A
            vector<unsigned long long> pw;
            pw.reserve(64);
            pw.push_back(1);
            while(true) {
                __uint128_t nxt = ( __uint128_t )pw.back() * A;
                if(nxt > ( __uint128_t )N) break;
                pw.push_back((unsigned long long)nxt);
            }
            int Tpw = pw.size();
            for(int t = 0; t < Tpw; ++t){
                ll p = pw[t];
                ll R = N / p;
                ll L = 0;
                if(t+1 < Tpw) L = N / pw[t+1];
                // else pw[t+1] > N so N/pw[t+1] = 0
                ll cnt = R - L;
                if(cnt <= 0) continue;
                // sum f = cnt*t + cnt*N - p * sum_{K=L+1..R} K
                // sumK = (L+1 + R)*cnt/2
                mint mcnt = cnt;
                mint contrib = mcnt * mint(t) + mcnt * mint(N);
                // sum K
                // use __int128 to compute sum safely mod
                __int128 s = (__int128)(L+1 + R) * cnt / 2;
                mint sumK = (long long)(s % MOD);
                contrib -= mint(p) * sumK;
                ans += contrib;
            }
        }
        cout << ans.v;
        if(T) cout << '\n';
    }
    return 0;
}
0