結果

問題 No.3505 Sum of Prod of Root
コンテスト
ユーザー Saku0512
提出日時 2026-04-18 02:44:49
言語 C++23
(gcc 15.2.0 + boost 1.89.0)
コンパイル:
g++-15 -O2 -lm -std=c++23 -Wuninitialized -DONLINE_JUDGE -o a.out _filename_
実行:
./a.out
結果
AC  
実行時間 359 ms / 2,000 ms
コード長 3,991 bytes
記録
記録タグの例:
初AC ショートコード 純ショートコード 純主流ショートコード 最速実行時間
コンパイル時間 3,068 ms
コンパイル使用メモリ 341,696 KB
実行使用メモリ 11,768 KB
最終ジャッジ日時 2026-04-18 02:44:55
合計ジャッジ時間 5,203 ms
ジャッジサーバーID
(参考情報)
judge2_0 / judge1_1
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 1
other AC * 13
権限があれば一括ダウンロードができます

ソースコード

diff #
raw source code

#include <bits/stdc++.h>
using namespace std;

const long long MOD = 998244353;
const unsigned long long INF = 2000000000000000000ULL; // 2 * 10^18

long long power(long long base, long long exp) {
    long long res = 1;
    base %= MOD;
    while (exp > 0) {
        if (exp % 2 == 1) res = res * base % MOD;
        base = base * base % MOD;
        exp /= 2;
    }
    return res;
}

long long inv2, inv6, inv30;

inline long long sum2(long long n) {
    long long nm = n % MOD;
    long long res = nm * (nm + 1) % MOD * (2 * nm + 1) % MOD;
    return res * inv6 % MOD;
}

inline long long sum3(long long n) {
    long long nm = n % MOD;
    long long res = nm * (nm + 1) % MOD * inv2 % MOD;
    return res * res % MOD;
}

inline long long sum4(long long n) {
    long long nm = n % MOD;
    long long res1 = nm * (nm + 1) % MOD * (2 * nm + 1) % MOD;
    long long res2 = (3 * nm % MOD * nm % MOD + 3 * nm % MOD - 1 + MOD) % MOD;
    return res1 * res2 % MOD * inv30 % MOD;
}

// べき乗計算(オーバーフロー対策で INF で頭打ちにする)
unsigned long long safe_pow(unsigned long long base, int exp) {
    unsigned __int128 res = 1;
    for (int i = 0; i < exp; ++i) {
        res *= base;
        if (res > INF) return INF;
    }
    return (unsigned long long)res;
}

// S(x) の計算
long long calc_S(long long x) {
    if (x <= 0) return 0;
    
    // sqrtl を使って高速に平方根の近似を求め、微調整する (O(1))
    long long m = sqrtl(x);
    while ((unsigned __int128)(m + 1) * (m + 1) <= x) m++;
    while ((unsigned __int128)m * m > x) m--;
    
    long long n = m - 1;
    long long Sa = 0;
    if (n > 0) {
        Sa = (2 * sum4(n) + 3 * sum3(n) + sum2(n)) % MOD;
    }
    
    // オーバーフロー防止のため __int128 経由で安全に MOD を取る
    long long count = (long long)((x - (unsigned __int128)m * m + 1) % MOD);
    long long first = (long long)(((unsigned __int128)m * m) % MOD);
    long long last = x % MOD;
    
    long long sum_i = (first + last) % MOD * count % MOD * inv2 % MOD;
    long long Sb = (m % MOD) * sum_i % MOD;
    
    return (Sa + Sb) % MOD;
}

int main() {
    // 高速化
    cin.tie(0)->sync_with_stdio(0);

    inv2 = power(2, MOD - 2);
    inv6 = power(6, MOD - 2);
    inv30 = power(30, MOD - 2);

    long long N;
    if (!(cin >> N)) return 0;

    // 区間の変化点(累乗数)を列挙
    vector<long long> P;
    P.push_back(1);
    for (long long x = 2; x <= 1000000; ++x) {
        long long v = x * x * x;
        P.push_back(v);
        while (v <= 1000000000000000000LL / x) {
            v *= x;
            P.push_back(v);
        }
    }
    P.push_back(N + 1);
    sort(P.begin(), P.end());
    P.erase(unique(P.begin(), P.end()), P.end());

    long long ans = 0;
    
    // 各 k における L^{1/k} の現在値と、次に値が変わる境界値
    long long root[60];
    unsigned long long next_p[60];
    for (int k = 3; k <= 59; ++k) {
        root[k] = 1;
        next_p[k] = safe_pow(2, k);
    }

    for (size_t i = 0; i < P.size() - 1; ++i) {
        long long L = P[i];
        long long R = P[i+1] - 1;
        if (L > N) break;
        if (R > N) R = N;
        if (L > R) continue;

        long long h = 1;
        // k = 3 〜 59 についての積を O(1) 感覚で求める
        for (int k = 3; k <= 59; ++k) {
            // L が次の境界を超えたら root を更新
            while ((unsigned long long)L >= next_p[k]) {
                root[k]++;
                next_p[k] = safe_pow(root[k] + 1, k);
            }
            // root が 1 になったら、それ以降の k も絶対に 1 なので計算を打ち切る(最強の定数倍高速化)
            if (root[k] == 1) break;
            
            h = h * (root[k] % MOD) % MOD;
        }

        long long sum_val = (calc_S(R) - calc_S(L - 1) + MOD) % MOD;
        ans = (ans + h * sum_val) % MOD;
    }

    cout << ans << '\n';

    return 0;
}
0