結果

問題 No.3505 Sum of Prod of Root
コンテスト
ユーザー 왕지후
提出日時 2026-04-18 01:50:34
言語 C++23
(gcc 15.2.0 + boost 1.89.0)
コンパイル:
g++-15 -O2 -lm -std=c++23 -Wuninitialized -DONLINE_JUDGE -o a.out _filename_
実行:
./a.out
結果
AC  
実行時間 223 ms / 2,000 ms
コード長 3,454 bytes
記録
記録タグの例:
初AC ショートコード 純ショートコード 純主流ショートコード 最速実行時間
コンパイル時間 763 ms
コンパイル使用メモリ 76,288 KB
実行使用メモリ 35,840 KB
最終ジャッジ日時 2026-04-18 01:50:56
合計ジャッジ時間 2,319 ms
ジャッジサーバーID
(参考情報)
judge1_1 / judge2_1
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 1
other AC * 13
権限があれば一括ダウンロードができます

ソースコード

diff #
raw source code

#include <stdio.h>
#include <stdlib.h>
#include <stdint.h>
#include <math.h>

typedef long long ll;
typedef __int128 i128;

#define MOD 998244353LL
#define MAXA 1000005
#define MAXE 1100000

typedef struct {
    ll x;
    ll mul;
} Event;

static Event ev[MAXE];
static ll inv[MAXA];
static int ecnt = 0;

static inline ll norm(ll x) {
    x %= MOD;
    if (x < 0) x += MOD;
    return x;
}

static ll mod_pow(ll a, ll e) {
    ll r = 1;
    while (e > 0) {
        if (e & 1) r = (ll)((i128)r * a % MOD);
        a = (ll)((i128)a * a % MOD);
        e >>= 1;
    }
    return r;
}

static int cmp_event(const void *pa, const void *pb) {
    ll a = ((const Event *)pa)->x;
    ll b = ((const Event *)pb)->x;
    return (a < b) ? -1 : (a > b);
}

static inline ll sum1(ll n) {
    n %= MOD;
    return (ll)((i128)n * ((n + 1) % MOD) % MOD * ((MOD + 1) / 2) % MOD);
}

static inline ll sum2(ll n) {
    static const ll INV6 = 166374059LL;
    ll a = n % MOD;
    ll b = (n + 1) % MOD;
    ll c = (2 * a + 1) % MOD;
    return (ll)((i128)a * b % MOD * c % MOD * INV6 % MOD);
}

static inline ll sum3(ll n) {
    static const ll INV4 = 748683265LL;
    ll a = n % MOD;
    ll b = (n + 1) % MOD;
    ll t = (ll)((i128)a * b % MOD);
    return (ll)((i128)t * t % MOD * INV4 % MOD);
}

static inline ll sum4(ll n) {
    static const ll INV30 = 432572553LL;
    ll a = n % MOD;
    ll b = (n + 1) % MOD;
    ll c = (2 * a + 1) % MOD;
    ll d = norm((ll)((3 * (i128)a * a) % MOD) + 3 * a - 1);
    return (ll)((i128)a * b % MOD * c % MOD * d % MOD * INV30 % MOD);
}

static inline ll isqrt_ll(ll n) {
    long double x = sqrtl((long double)n);
    ll r = (ll)x;
    while ((i128)(r + 1) * (r + 1) <= n) r++;
    while ((i128)r * r > n) r--;
    return r;
}

/*
  F(n) = sum_{i=1}^n i * floor(sqrt(i))
*/
static ll pref(ll n) {
    if (n <= 0) return 0;

    ll m = isqrt_ll(n);
    ll k = m - 1;

    ll full = 0;
    if (k >= 1) {
        ll s2 = sum2(k);
        ll s3 = sum3(k);
        ll s4 = sum4(k);
        full = norm(2 * s4 + 3 * s3 + s2);
    }

    ll mm = m % MOD;
    ll left = (ll)((i128)m * m - 1);
    ll tail = norm(sum1(n) - sum1(left));
    tail = (ll)((i128)tail * mm % MOD);

    return norm(full + tail);
}

int main(void) {
    ll N;
    scanf("%lld", &N);

    inv[1] = 1;
    for (int i = 2; i < MAXA; i++) {
        inv[i] = MOD - (ll)((i128)(MOD / i) * inv[MOD % i] % MOD);
    }

    for (int k = 3; k <= 60; k++) {
        for (ll a = 2;; a++) {
            i128 p = 1;
            for (int t = 0; t < k; t++) {
                p *= a;
                if (p > N) break;
            }
            if (p > N) break;

            ll x = (ll)p;
            ll mul = (ll)((i128)(a % MOD) * inv[a - 1] % MOD);

            ev[ecnt].x = x;
            ev[ecnt].mul = mul;
            ecnt++;
        }
    }

    qsort(ev, ecnt, sizeof(Event), cmp_event);

    ll ans = 0;
    ll H = 1;
    ll prev = 1;
    int i = 0;

    while (i < ecnt) {
        ll x = ev[i].x;

        if (prev <= x - 1) {
            ll seg = norm(pref(x - 1) - pref(prev - 1));
            ans = norm(ans + (ll)((i128)H * seg % MOD));
        }

        while (i < ecnt && ev[i].x == x) {
            H = (ll)((i128)H * ev[i].mul % MOD);
            i++;
        }

        prev = x;
    }

    if (prev <= N) {
        ll seg = norm(pref(N) - pref(prev - 1));
        ans = norm(ans + (ll)((i128)H * seg % MOD));
    }

    printf("%lld\n", ans);
    return 0;
}
0