結果

問題 No.3106 Simple Math Problem 3
ユーザー pockyny
提出日時 2025-04-26 23:30:36
言語 C++17
(gcc 13.3.0 + boost 1.87.0)
結果
WA  
実行時間 -
コード長 3,335 bytes
コンパイル時間 1,697 ms
コンパイル使用メモリ 114,576 KB
実行使用メモリ 183,824 KB
最終ジャッジ日時 2025-04-26 23:31:26
合計ジャッジ時間 47,317 ms
ジャッジサーバーID
(参考情報)
judge2 / judge1
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 2 WA * 2
other AC * 2 WA * 40
権限があれば一括ダウンロードができます

ソースコード

diff #

#include <iostream>
#include <vector>
#include <set>
#include <cmath>
#include <atcoder/modint>
#include <cassert>

using namespace std;
using namespace atcoder;
using mint = modint998244353;
typedef long long ll;
int main(){
    ll i,n; cin >> n;
    mint ans = 0;
    vector<pair<pair<ll,ll>,ll>> v;
    {
        ll val = n;
        while(val){
            ll le = n/(val + 1),ri = n/val;
            v.push_back({{le,ri},val});
            val = n/(ri + 1);
        }
    }
    vector<mint> sum;
    sum.push_back(0);
    for(i=0;i<v.size();i++){
        auto [lr,val] = v[i];
        auto [le,ri] = lr;
        sum.push_back(sum.back() + (mint)val*(ri - le));
    }

    // calc sum_i in (0,x] floor(n/i)
    auto f = [&](ll x) -> mint{
        if(x==0) return 0LL;
        pair<pair<ll,ll>,ll> p = {{x,-1},-1};
        ll id = lower_bound(v.begin(),v.end(),p) - v.begin();
        mint a = sum[id];
        mint b = v[id - 1].second*((mint)v[id - 1].first.second - (mint)x);
        return a - b;
    };

    // sum_i in (l,r] n/i
    auto g = [&](ll l,ll r){
        return f(r) - f(l);
    };

    // ll le = 1,ri = n;
    ll le = max(1LL,n - (ll)sqrtl(n) - 10LL),ri = n;
    vector<ll> v_val(ri - le + 1);
    for(ll i=le;i<=ri;i++) v_val[i - le] = i;
    vector<vector<pair<ll,ll>>> fa(ri - le + 1);
    for(i=2;i*i<=n;i++){
        ll l = (le + i - 1)/i*i;
        for(ll j = l - le;j<fa.size();j+=i){
            if(v_val[j]%i){
                break;
            }
            int c = 0;
            while(v_val[j]%i==0){
                v_val[j] /= i;
                c++;
            }
            fa[j].push_back({i,c});
        }
    }

    vector<ll> divs;
    auto dfs = [&](auto &&self,ll x,int id,int len){
        if(len==fa[id].size()){
            divs.push_back(x);
            return;
        }
        for(int y=0;y<=fa[id][len].second;y++){
            self(self,x,id,len + 1);
            // assert((id + le)%fa[id][len].first==0);
            x *= fa[id][len].first;
        }
    };
    for(i=0;i<v_val.size();i++){
        if(v_val[i]>1) fa[i].push_back({v_val[i],1});
    }
    set<ll> s;
    for(ll c=1;c*c<=n;c++){
        ll cn = max(0LL,(n/c - 1 - c));
        mint sum = (-c + 1)*cn;
        if(cn) sum += g(c,n/c - 1);
        // for(ll b=c + 1;b<=n/c - 1;b++){
        //     if(n%b<c) sum--;
        // }

        // n%b<cである ⇔ n%b==x かつ x<cなので、n%b==xなるbをいれて余分なものを省いていく
        // n%b==x ⇔ b|(n - x)かつb>xなので(n - x)の約数列挙をする (最初に列挙)
        if(n - (c - 1) - le>=0){
            divs.clear();
            dfs(dfs,1LL,(n - (c - 1) - le),0);
            divs.clear();
            // cout << divs.size() << " " << n - (c - 1) << endl;
            // for(ll x:divs) cout << x << " ";
            // cout << "\n";
            for(ll d:divs){
                // assert((n - (c - 1))%d==0);
                if(c + 1<=d && d<n/c){
                    // assert(n%d==(c - 1));
                    s.insert(d);
                }    
            }
        }
        while(s.size() && *s.begin()<=c){
            s.erase(s.begin());
        }
        while(s.size() && *s.rbegin()>=n/c){
            s.erase(*s.rbegin());
        }
        sum -= s.size();
        ans += c*sum;
    }
    cout << ans.val() << "\n";
}
0