結果

問題 No.3505 Sum of Prod of Root
コンテスト
ユーザー Naru820
提出日時 2026-03-17 15:23:13
言語 C++23
(gcc 15.2.0 + boost 1.89.0)
コンパイル:
g++-15 -O2 -lm -std=c++23 -Wuninitialized -DONLINE_JUDGE -o a.out _filename_
実行:
./a.out
結果
AC  
実行時間 450 ms / 2,000 ms
コード長 4,097 bytes
記録
記録タグの例:
初AC ショートコード 純ショートコード 純主流ショートコード 最速実行時間
コンパイル時間 2,174 ms
コンパイル使用メモリ 191,020 KB
実行使用メモリ 28,024 KB
最終ジャッジ日時 2026-04-17 19:40:27
合計ジャッジ時間 5,142 ms
ジャッジサーバーID
(参考情報)
judge2_0 / judge3_1
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 1
other AC * 13
権限があれば一括ダウンロードができます

ソースコード

diff #
raw source code

// gemini3.1 pro
#include <iostream>
#include <vector>
#include <cmath>
#include <algorithm>

using namespace std;

const long long MOD = 998244353;

// 繰り返し二乗法による累乗と逆元
long long power(long long base, long long exp) {
    long long res = 1;
    base %= MOD;
    while (exp > 0) {
        if (exp % 2 == 1) res = (res * base) % MOD;
        base = (base * base) % MOD;
        exp /= 2;
    }
    return res;
}
long long modInverse(long long n) { return power(n, MOD - 2); }

long long inv2, inv6, inv30;
void initInverses() {
    inv2 = modInverse(2);
    inv6 = modInverse(6);
    inv30 = modInverse(30);
}

// べき乗和の公式
long long S2(long long x) {
    x %= MOD;
    return x * (x + 1) % MOD * (2 * x + 1) % MOD * inv6 % MOD;
}
long long S3(long long x) {
    x %= MOD;
    long long res = x * (x + 1) % MOD * inv2 % MOD;
    return res * res % MOD;
}
long long S4(long long x) {
    x %= MOD;
    long long res = x * (x + 1) % MOD * (2 * x + 1) % MOD;
    long long term = (3 * x % MOD * (x + 1) % MOD - 1 + MOD) % MOD;
    return res * term % MOD * inv30 % MOD;
}

// S(M) = \sum_{i=1}^M i * floor(sqrt(i)) を O(1) で計算
long long S(long long M) {
    if (M <= 0) return 0;
    long long V = sqrt(M);
    while ((V + 1) * (V + 1) <= M) V++;
    while (V * V > M) V--;
    
    long long v1 = V - 1;
    long long part1 = (2 * S4(v1) % MOD + 3 * S3(v1) % MOD + S2(v1)) % MOD;
    
    long long M_mod = M % MOD;
    long long V2_mod = (V % MOD) * (V % MOD) % MOD;
    long long count = (M_mod - V2_mod + 1 + MOD) % MOD;
    long long sum_i = (V2_mod + M_mod) % MOD * count % MOD * inv2 % MOD;
    
    long long part2 = (V % MOD) * sum_i % MOD;
    
    return (part1 + part2) % MOD;
}

// 値が変化する「イベント」を管理する構造体
struct Event {
    long long val; // 変化するタイミングの i (a^k)
    int k;         // どの k 乗根が変化するか
    long long a;   // 変化後の値
    
    bool operator<(const Event& other) const {
        return val < other.val;
    }
};

int main() {
    // 高速入出力
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);
    
    initInverses();
    long long N;
    if (!(cin >> N)) return 0;
    
    // イベントの列挙
    vector<Event> events;
    for (int k = 3; k <= 60; ++k) {
        long long a = 2;
        while (true) {
            __int128 p = 1;
            for (int i = 0; i < k; ++i) p *= a;
            
            if (p > N) break; // N を超えたら次の k へ
            
            events.push_back({(long long)p, k, a});
            a++;
        }
    }
    
    // i が小さい順にイベントをソート
    sort(events.begin(), events.end());
    
    long long ans = 0;
    long long current_L = 1;
    
    // v[k] = 現在の floor( i^(1/k) ) の値。初期値はすべて 1
    vector<long long> v(65, 1);
    
    size_t idx = 0;
    while (idx < events.size()) {
        long long V = events[idx].val; // 次の変化点
        
        // 変化点に達するまでの区間 [current_L, V-1] の和を計算して加算
        if (current_L < V) {
            long long H = 1;
            for (int k = 3; k <= 60; ++k) {
                H = (H * (v[k] % MOD)) % MOD;
            }
            long long sum_interval = (S(V - 1) - S(current_L - 1) + MOD) % MOD;
            ans = (ans + H * sum_interval) % MOD;
            
            current_L = V;
        }
        
        // 同じ i (V) で複数の k 乗根が同時に変化する可能性があるため、まとめて更新
        while (idx < events.size() && events[idx].val == V) {
            v[events[idx].k] = events[idx].a;
            idx++;
        }
    }
    
    // 最後に残った区間 [current_L, N] の和を加算
    if (current_L <= N) {
        long long H = 1;
        for (int k = 3; k <= 60; ++k) {
            H = (H * (v[k] % MOD)) % MOD;
        }
        long long sum_interval = (S(N) - S(current_L - 1) + MOD) % MOD;
        ans = (ans + H * sum_interval) % MOD;
    }
    
    cout << ans << "\n";
    return 0;
}
0