結果

問題 No.3505 Sum of Prod of Root
コンテスト
ユーザー よいちなすの
提出日時 2026-04-18 15:15:06
言語 C++23
(gcc 15.2.0 + boost 1.89.0)
コンパイル:
g++-15 -O2 -lm -std=c++23 -Wuninitialized -DONLINE_JUDGE -o a.out _filename_
実行:
./a.out
結果
WA  
実行時間 -
コード長 2,468 bytes
記録
記録タグの例:
初AC ショートコード 純ショートコード 純主流ショートコード 最速実行時間
コンパイル時間 2,635 ms
コンパイル使用メモリ 346,736 KB
実行使用メモリ 27,044 KB
最終ジャッジ日時 2026-04-18 15:15:29
合計ジャッジ時間 8,133 ms
ジャッジサーバーID
(参考情報)
judge1_1 / judge2_0
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample -- * 1
other WA * 3 TLE * 1 -- * 9
権限があれば一括ダウンロードができます

ソースコード

diff #
raw source code

#include<bits/stdc++.h>
using namespace std;

typedef long long ll;
typedef __int128_t int128;

const ll MOD = 998244353;

ll power(ll base, ll exp) {
    ll res = 1; base %= MOD;
    while (exp > 0) {
        if (exp % 2 == 1) res = (res * base) % MOD;
        base = (base * base) % MOD;
        exp /= 2;
    }
    return res;
}

ll modInverse(ll n) { return power(n, MOD - 2); }

const ll INV2 = 499122177;
const ll INV6 = 166374059;
const ll INV30 = 33274812;

ll sum_i(ll a, ll b) {
    if (a > b) return 0;
    int128 n = (int128)b - a + 1;
    return (ll)(((int128)(a + b) * n / 2) % MOD);
}

ll sum(ll n) {
    if (n <= 0) return 0;
    ll v = n % MOD;
    ll s2 = v * (v + 1) % MOD * (2 * v + 1) % MOD * INV6 % MOD;
    ll s3 = (v * (v + 1) % MOD * INV2 % MOD); s3 = s3 * s3 % MOD;
    ll X = (3 * v % MOD * v % MOD + 3 * v % MOD - 1 + MOD) % MOD;
    ll s4 = v * (v + 1) % MOD * (2 * v + 1) % MOD * X % MOD * INV30 % MOD;
    return (2 * s4 + 3 * s3 + s2) % MOD;
}

ll sqr(ll n) {
    if (n <= 0) return 0;
    ll x = sqrtl(n);
    while ((int128)(x + 1) * (x + 1) <= n) x++;
    while ((int128)x * x > n) x--;
    return x;
}

ll sol(ll n, int k) {
    if (n <= 1) return n;
    ll r = pow(n, 1.0 / k);
    r += 2;
    while (true) {
        int128 p = 1; bool ok = true;
        for (int i = 0; i < k; i++) { p *= r; if (p > n) { ok = false; break; } }
        if (ok) return r;
        r--;
    }
}

int main() {
    ll N; cin >> N;
    set<ll> pts = {1, N + 1};
    for (int k = 3; k <= 60; k++) {
        for (ll m = 2; ; m++) {
            int128 p = 1; bool over = false;
            for (int i = 0; i < k; i++) { p *= m; if (p > N) { over = true; break; } }
            if (over) break;
            pts.insert((ll)p);
        }
    }
    vector<ll> P(pts.begin(), pts.end());
    ll ans = 0;
    for (size_t j = 0; j < P.size() - 1; j++) {
        ll L = P[j], R = P[j + 1];
        ll C = 1;
        for (int k = 3; k <= 60; k++) C = (C * (sol(L, k) % MOD)) % MOD;
        
        ll SL = sqr(L), SR = sqr(R - 1);
        ll sub = 0;
        if (SL == SR) {
            sub = (SL % MOD) * sum_i(L, R - 1) % MOD;
        } else {
            sub = (sub + (SL % MOD) * sum_i(L, (int128)(SL + 1) * (SL + 1) - 1)) % MOD;
            sub = (sub + sum(SR - 1) - sum(SL) + MOD) % MOD;
            sub = (sub + (SR % MOD) * sum_i((int128)SR * SR, R - 1)) % MOD;
        }
        ans = (ans + C * sub) % MOD;
    }
    cout << ans << endl;
    return 0;
}
0