結果

問題 No.3505 Sum of Prod of Root
コンテスト
ユーザー Saku0512
提出日時 2026-04-18 02:41:29
言語 C++23
(gcc 15.2.0 + boost 1.89.0)
コンパイル:
g++-15 -O2 -lm -std=c++23 -Wuninitialized -DONLINE_JUDGE -o a.out _filename_
実行:
./a.out
結果
TLE  
実行時間 -
コード長 4,504 bytes
記録
記録タグの例:
初AC ショートコード 純ショートコード 純主流ショートコード 最速実行時間
コンパイル時間 1,786 ms
コンパイル使用メモリ 190,260 KB
実行使用メモリ 18,276 KB
最終ジャッジ日時 2026-04-18 02:42:13
合計ジャッジ時間 8,448 ms
ジャッジサーバーID
(参考情報)
judge1_0 / judge2_0
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample -- * 1
other AC * 4 TLE * 1 -- * 8
権限があれば一括ダウンロードができます

ソースコード

diff #
raw source code

#include <algorithm>
#include <cmath>
#include <iostream>
#include <utility>
#include <vector>

using namespace std;

// Prime modulus for every modular computation in this file; a compile-time constant
// so it cannot be modified accidentally and can be folded by the compiler.
constexpr long long MOD = 998244353;

// Modular exponentiation by square-and-multiply: returns base^exp modulo MOD.
// The loop only runs while exp > 0, so any exp <= 0 yields 1.
long long calculate_power(long long base, long long exp) {
    long long acc = 1;
    long long square = base % MOD;
    for (; exp > 0; exp >>= 1) {
        if (exp & 1) {
            acc = acc * square % MOD;
        }
        square = square * square % MOD;
    }
    return acc;
}

// Modular inverses of 2, 6 and 30 modulo MOD. They are zero until main() assigns them,
// so the sum_of_* helpers must not be called before that.
long long inv2 = 0;
long long inv6 = 0;
long long inv30 = 0;

// Returns 1^2 + 2^2 + ... + n^2 modulo MOD using n(n+1)(2n+1)/6.
// Division by 6 is done by multiplying with inv6, which main() must set first.
long long sum_of_squares(long long n) {
    const long long r = n % MOD;
    long long acc = r * (r + 1) % MOD;
    acc = acc * (2 * r + 1) % MOD;
    return acc * inv6 % MOD;
}

// Returns 1^3 + 2^3 + ... + n^3 modulo MOD, i.e. the square of n(n+1)/2.
// Halving uses inv2, which main() must set first.
long long sum_of_cubes(long long n) {
    const long long r = n % MOD;
    const long long triangular = r * (r + 1) % MOD * inv2 % MOD;
    return triangular * triangular % MOD;
}

// Returns 1^4 + 2^4 + ... + n^4 modulo MOD via n(n+1)(2n+1)(3n^2+3n-1)/30.
// Division by 30 uses inv30, which main() must set first.
long long sum_of_power4(long long n) {
    const long long r = n % MOD;
    const long long three_r = 3 * r % MOD;

    // n(n+1)(2n+1)
    const long long cubic = r * (r + 1) % MOD * (2 * r + 1) % MOD;
    // 3n^2 + 3n - 1; MOD is added before reducing so the "- 1" cannot go negative.
    const long long quadratic = (three_r * r % MOD + three_r - 1 + MOD) % MOD;

    return cubic * quadratic % MOD * inv30 % MOD;
}

// Returns true when base^k <= limit, for non-negative base and limit.
// The product is accumulated in 128 bits and the loop exits as soon as it exceeds
// limit, so it never grows past limit * base and cannot overflow.
static bool power_at_most(long long base, int k, long long limit) {
    unsigned __int128 acc = 1;
    for (int i = 0; i < k; ++i) {
        acc *= static_cast<unsigned __int128>(base);
        if (acc > static_cast<unsigned __int128>(limit)) {
            return false;
        }
    }
    return true;
}

// Returns floor(n^(1/k)) for n >= 2 and k >= 1 (k must be at least 1).
// Any n <= 1 is returned unchanged.
// A floating-point estimate (sqrt for k == 2, pow otherwise) is corrected with exact
// integer power checks, so rounding error in the estimate cannot cause an off-by-one.
long long get_integer_root(long long n, int k) {
    if (n <= 1) {
        return n;
    }

    const double estimate = (k == 2)
        ? std::sqrt(static_cast<double>(n))
        : std::pow(static_cast<double>(n), 1.0 / k);
    long long r = std::llround(estimate);
    if (r < 1) {
        r = 1;  // 1^k <= n always holds for n >= 2, so 1 is a safe starting point
    }

    // After the first loop r^k <= n; the second keeps that invariant while
    // advancing, so it ends with r^k <= n < (r+1)^k.
    while (!power_at_most(r, k, n)) {
        --r;  // estimate was too high
    }
    while (power_at_most(r + 1, k, n)) {
        ++r;  // estimate was too low
    }
    return r;
}

// Prefix sum S(x) = sum_{i=1}^{x} i * floor(sqrt(i)) modulo MOD; returns 0 for x <= 0.
// Relies on inv2/inv6/inv30 having been initialised by main().
long long calculate_S(long long x) {
    if (x < 1) {
        return 0;
    }

    const long long m = get_integer_root(x, 2);

    // Complete blocks j = 1 .. m-1: the block [j^2, (j+1)^2 - 1] has 2j+1 terms
    // summing to (2j+1)(j^2+j); scaled by j this is 2j^4 + 3j^3 + j^2.
    long long complete = 0;
    const long long last_full = m - 1;
    if (last_full >= 1) {
        complete = (2 * sum_of_power4(last_full) % MOD
                    + 3 * sum_of_cubes(last_full) % MOD
                    + sum_of_squares(last_full) % MOD) % MOD;
    }

    // Partial block [m^2, x]: an arithmetic series, every term scaled by m.
    const long long terms = (x - m * m + 1) % MOD;
    long long series = ((m * m) % MOD + x % MOD) % MOD;
    series = series * terms % MOD * inv2 % MOD;
    const long long partial = (m % MOD) * series % MOD;

    return (complete + partial) % MOD;
}

// Reads N and prints sum_{i=1}^{N} prod_{k>=2} floor(i^(1/k)) modulo MOD.
//
// The k = 2 factor is folded into the prefix sums calculate_S(x) = sum i*floor(sqrt(i)).
// The remaining factor prod_{k>=3} floor(i^(1/k)) only changes at perfect powers base^k
// (base >= 2, k >= 3), so [1, N] splits into segments on which it is constant.
//
// For N near 10^18 there are about 10^6 such points. Recomputing every k-th root from
// scratch per segment (up to 57 roots, each with pow() plus exact correction) is too
// slow. Instead, each point is recorded as an event (value, k); root[k] grows by exactly
// one at each of its events, and the product is rebuilt from that small table.
int main() {
    inv2 = calculate_power(2, MOD - 2);
    inv6 = calculate_power(6, MOD - 2);
    inv30 = calculate_power(30, MOD - 2);

    long long N;
    std::cin >> N;

    // All (value, k) with value = base^k <= N, base >= 2, k >= 3. base^3 <= N is tested
    // as base <= N / base / base and the next power as v <= N / base, so nothing overflows.
    std::vector<std::pair<long long, int>> events;
    for (long long base = 2; base <= N / base / base; ++base) {
        long long v = base * base * base;
        int k = 3;
        events.emplace_back(v, k);
        while (v <= N / base) {
            v *= base;
            ++k;
            events.emplace_back(v, k);
        }
    }
    std::sort(events.begin(), events.end());

    // root[k] = floor(i^(1/k)) on the current segment. Since 2^63 exceeds any long long,
    // k never goes above 62.
    std::vector<long long> root(64, 1);
    long long product = 1;         // prod_{k>=3} root[k] mod MOD
    long long segment_start = 1;   // first i of the current segment
    long long prefix_before = 0;   // calculate_S(segment_start - 1)
    long long total_answer = 0;

    // Adds product * (S(segment_end) - S(segment_start - 1)) and remembers S(segment_end)
    // so the next segment does not recompute it.
    auto close_segment = [&](long long segment_end) {
        const long long prefix_end = calculate_S(segment_end);
        const long long sum_val = (prefix_end - prefix_before + MOD) % MOD;
        total_answer = (total_answer + product * sum_val) % MOD;
        prefix_before = prefix_end;
    };

    for (std::size_t idx = 0; idx < events.size();) {
        const long long v = events[idx].first;
        close_segment(v - 1);  // v > segment_start because event values are sorted and >= 8

        // Several exponents can step at the same value (e.g. 64 = 4^3 = 2^6).
        for (; idx < events.size() && events[idx].first == v; ++idx) {
            ++root[events[idx].second];
        }

        // floor(i^(1/k)) is non-increasing in k, so stop at the first root equal to 1.
        product = 1;
        for (int k = 3; k < static_cast<int>(root.size()) && root[k] > 1; ++k) {
            product = product * (root[k] % MOD) % MOD;
        }
        segment_start = v;
    }
    if (segment_start <= N) {
        close_segment(N);
    }

    std::cout << total_answer << '\n';

    return 0;
}
0