結果

問題 No.3505 Sum of Prod of Root
コンテスト
ユーザー Enderaoe Lyther
提出日時 2026-04-17 21:56:19
言語 C++23
(gcc 15.2.0 + boost 1.89.0)
コンパイル:
g++-15 -O2 -lm -std=c++23 -Wuninitialized -DONLINE_JUDGE -o a.out _filename_
実行:
./a.out
結果
AC  
実行時間 272 ms / 2,000 ms
コード長 3,903 bytes
記録
記録タグの例:
初AC ショートコード 純ショートコード 純主流ショートコード 最速実行時間
コンパイル時間 1,967 ms
コンパイル使用メモリ 193,848 KB
実行使用メモリ 43,900 KB
最終ジャッジ日時 2026-04-17 21:56:28
合計ジャッジ時間 4,136 ms
ジャッジサーバーID
(参考情報)
judge2_1 / judge3_0
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 1
other AC * 13
権限があれば一括ダウンロードができます

ソースコード

diff #
raw source code

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>

using namespace std;

using int64 = long long;
using i128 = __int128_t;

static constexpr int64 MOD = 998244353;
static constexpr int64 INV2 = (MOD + 1) / 2;
static constexpr int64 INV6 = 166374059;   // 6^{-1} mod MOD
static constexpr int64 INV30 = 432572553;  // 30^{-1} mod MOD

int64 mul_mod(int64 a, int64 b) {
    return (i128)a * b % MOD;
}

int64 add_mod(int64 a, int64 b) {
    a += b;
    if (a >= MOD) a -= MOD;
    return a;
}

int64 sub_mod(int64 a, int64 b) {
    a -= b;
    if (a < 0) a += MOD;
    return a;
}

bool le_pow(int64 base, int k, int64 limit) {
    i128 cur = 1;
    for (int i = 0; i < k; ++i) {
        cur *= base;
        if (cur > limit) return false;
    }
    return true;
}

int64 kth_root_floor(int64 n, int k) {
    if (k == 1 || n <= 1) return n;
    long double x = pow((long double)n, 1.0L / k);
    int64 r = (int64)x;
    if (r < 1) r = 1;
    while (le_pow(r + 1, k, n)) ++r;
    while (!le_pow(r, k, n)) --r;
    return r;
}

int64 isqrt_floor(int64 n) {
    int64 r = (int64)sqrt((long double)n);
    while ((i128)(r + 1) * (r + 1) <= n) ++r;
    while ((i128)r * r > n) --r;
    return r;
}

int64 tri(int64 n) {
    if (n <= 0) return 0;
    return mul_mod(mul_mod(n % MOD, (n + 1) % MOD), INV2);
}

int64 sum2(int64 n) {
    if (n <= 0) return 0;
    int64 a = n % MOD;
    int64 b = (n + 1) % MOD;
    int64 c = (2 * (n % MOD) + 1) % MOD;
    return mul_mod(mul_mod(a, b), mul_mod(c, INV6));
}

int64 sum3(int64 n) {
    if (n <= 0) return 0;
    int64 t = tri(n);
    return mul_mod(t, t);
}

int64 sum4(int64 n) {
    if (n <= 0) return 0;
    int64 a = n % MOD;
    int64 b = (n + 1) % MOD;
    int64 c = (2 * (n % MOD) + 1) % MOD;
    int64 d = (3 * mul_mod(a, a) + 3 * a - 1) % MOD;
    if (d < 0) d += MOD;
    return mul_mod(mul_mod(mul_mod(a, b), c), mul_mod(d, INV30));
}

// W(x) = sum_{n=1}^{x} n * floor(sqrt(n))
int64 prefix_weight(int64 x) {
    if (x <= 0) return 0;

    int64 s = isqrt_floor(x);
    int64 m = s - 1;

    int64 full = 0;
    full = add_mod(full, mul_mod(2, sum4(m)));
    full = add_mod(full, mul_mod(3, sum3(m)));
    full = add_mod(full, sum2(m));

    int64 tail = mul_mod(s % MOD, sub_mod(tri(x), tri((int64)((i128)s * s - 1))));
    return add_mod(full, tail);
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    int64 N;
    cin >> N;

    int64 max_base = kth_root_floor(N, 3);
    vector<int64> inv(max<int64>(max_base + 1, 2), 1);
    for (int64 i = 2; i <= max_base; ++i) {
        inv[i] = MOD - mul_mod(MOD / i, inv[MOD % i]);
    }

    vector<pair<int64, int64>> events;
    for (int k = 3; k <= 60; ++k) {
        int64 lim = kth_root_floor(N, k);
        for (int64 b = 2; b <= lim; ++b) {
            i128 p = 1;
            for (int i = 0; i < k; ++i) p *= b;
            events.push_back({(int64)p, mul_mod(b % MOD, inv[b - 1])});
        }
    }

    sort(events.begin(), events.end());

    vector<pair<int64, int64>> merged;
    merged.reserve(events.size());
    for (auto [x, mul] : events) {
        if (!merged.empty() && merged.back().first == x) {
            merged.back().second = mul_mod(merged.back().second, mul);
        } else {
            merged.push_back({x, mul});
        }
    }

    int64 ans = 0;
    int64 cur = 1;
    int64 h = 1;  // current value of H(n) on the active interval

    for (auto [x, mul] : merged) {
        if (cur <= x - 1) {
            int64 segment = sub_mod(prefix_weight(x - 1), prefix_weight(cur - 1));
            ans = add_mod(ans, mul_mod(h, segment));
        }
        h = mul_mod(h, mul);
        cur = x;
    }

    if (cur <= N) {
        int64 segment = sub_mod(prefix_weight(N), prefix_weight(cur - 1));
        ans = add_mod(ans, mul_mod(h, segment));
    }

    cout << ans << '\n';
    return 0;
}
0