結果
| 問題 | No.3505 Sum of Prod of Root |
|---|---|
| コンテスト | |
| ユーザー | |
| 提出日時 | 2026-04-18 04:45:11 |
| 言語 | C++17 (gcc 15.2.0 + boost 1.89.0) |
| 結果 | AC |
| 実行時間 | 279 ms / 2,000 ms |
| コード長 | 4,458 bytes |
| 記録 | |
| コンパイル時間 | 1,364 ms |
| コンパイル使用メモリ | 225,752 KB |
| 実行使用メモリ | 27,552 KB |
| 最終ジャッジ日時 | 2026-04-18 04:45:15 |
| 合計ジャッジ時間 | 3,212 ms |
| ジャッジサーバーID (参考情報) | judge1_0 / judge3_0 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 1 |
| other | AC * 13 |
ソースコード
#include <bits/stdc++.h>
using namespace std;
using int64 = long long;
using i128 = __int128_t;

// Prime modulus for every result, plus the modular inverses used by the
// closed-form power sums below. Each satisfies k * INVk == 1 (mod MOD).
static constexpr long long MOD = 998244353;
static constexpr long long INV2 = (MOD + 1) / 2;  // 2^{-1}: MOD is odd, so (MOD + 1) / 2 is exact
static constexpr long long INV6 = 166374059;      // 6^{-1} mod MOD
static constexpr long long INV30 = 432572553;     // 30^{-1} mod MOD
// Computes a^e mod MOD by binary (square-and-multiply) exponentiation.
// The base is first reduced into [0, MOD), so negative or oversized bases
// still produce a canonical, non-negative residue. For e <= 0 the loop never
// runs and the result is 1.
long long mod_pow(long long a, long long e) {
    a %= MOD;
    if (a < 0) a += MOD;  // C++ '%' keeps the dividend's sign; normalize it.
    long long r = 1 % MOD;
    while (e > 0) {
        if (e & 1) r = (long long)((__int128)r * a % MOD);
        a = (long long)((__int128)a * a % MOD);
        e >>= 1;
    }
    return r;
}
long long isqrt_ll(long long n) {
long long x = sqrtl((long double)n);
while ((i128)(x + 1) * (x + 1) <= n) ++x;
while ((i128)x * x > n) --x;
return x;
}
long long icbrt_ll(long long n) {
long long x = cbrtl((long double)n);
while ((i128)(x + 1) * (x + 1) * (x + 1) <= n) ++x;
while ((i128)x * x * x > n) --x;
return x;
}
// Returns (1 + 2 + ... + n) mod MOD, i.e. n(n+1)/2, for n >= 0.
long long sum1(long long n) {
    const long long nm = n % MOD;
    const long long next = (nm + 1) % MOD;
    const __int128 prod = static_cast<__int128>(nm) * next % MOD;
    return static_cast<long long>(prod * INV2 % MOD);
}
// Returns (1^2 + 2^2 + ... + n^2) mod MOD via n(n+1)(2n+1)/6, for n >= 0.
long long sum2(long long n) {
    const long long r = n % MOD;
    __int128 acc = r;
    acc = acc * ((n + 1) % MOD) % MOD;
    acc = acc * ((2 * r + 1) % MOD) % MOD;
    return static_cast<long long>(acc * INV6 % MOD);
}
// Returns (1^3 + 2^3 + ... + n^3) mod MOD. The sum of the first n cubes
// equals the square of the n-th triangular number (Nicomachus' identity).
long long sum3(long long n) {
    const __int128 tri = sum1(n);
    return static_cast<long long>(tri * tri % MOD);
}
// Returns (1^4 + 2^4 + ... + n^4) mod MOD via
// n(n+1)(2n+1)(3n^2 + 3n - 1)/30, for n >= 0.
long long sum4(long long n) {
    const long long r = n % MOD;
    const long long sq = static_cast<long long>(static_cast<__int128>(n) * n % MOD);
    long long quad = (3 * sq % MOD + 3 * r - 1) % MOD;
    if (quad < 0) quad += MOD;  // the "- 1" can make this -1 when n % MOD == 0
    __int128 acc = r;
    acc = acc * ((n + 1) % MOD) % MOD;
    acc = acc * ((2 * r + 1) % MOD) % MOD;
    acc = acc * quad % MOD;
    return static_cast<long long>(acc * INV30 % MOD);
}
// Returns (l + (l+1) + ... + r) mod MOD, or 0 for an empty range (l > r).
// Computed as count * (first + last) / 2, with the halving done through INV2.
long long range_sum(long long l, long long r) {
    if (r < l) return 0;
    const long long cnt = (r - l + 1) % MOD;
    const long long ends = (l % MOD + r % MOD) % MOD;
    return static_cast<long long>(static_cast<__int128>(ends) * cnt % MOD * INV2 % MOD);
}
// full blocks: sum_{t=1}^{m} (2t^4 + 3t^3 + t^2)
long long full_blocks(long long m) {
if (m <= 0) return 0;
long long s2 = sum2(m);
long long s3 = sum3(m);
long long s4 = sum4(m);
long long res = 0;
res = (res + 2LL * s4) % MOD;
res = (res + 3LL * s3) % MOD;
res = (res + s2) % MOD;
return res;
}
// F(n) = sum_{i=1}^{n} i * floor(sqrt(i))  (mod MOD); F(n) = 0 for n <= 0.
// With s = floor(sqrt(n)), the blocks t = 1..s-1 are complete, and every
// remaining index in [s^2, n] has floor(sqrt(i)) = s.
long long calcF(long long n) {
    if (n < 1) return 0;
    const long long s = isqrt_ll(n);
    const long long partial = range_sum(s * s, n);
    const __int128 total = full_blocks(s - 1) + static_cast<__int128>(s % MOD) * partial;
    return static_cast<long long>(total % MOD);
}
// Reads N and prints sum_{i=1}^{N} prod_{k>=1} floor(i^{1/k}) mod MOD.
//
// The k = 1 and k = 2 factors give i * floor(sqrt(i)), whose prefix sums are
// calcF. The remaining factor R(i) = prod_{k>=3} floor(i^{1/k}) is piecewise
// constant. It changes only at perfect powers x = a^k (k >= 3), where
// floor(x^{1/k}) steps from a - 1 to a. The events are swept in increasing
// order, and each constant stretch [curL, x - 1] contributes
// R * (F(x - 1) - F(curL - 1)).
int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    long long N;
    cin >> N;

    // Collect every perfect power x = a^k with k >= 3 and x <= N as an event (x, a).
    vector<pair<long long, int>> events;
    const long long lim = icbrt_ll(N);
    for (long long a = 2; a <= lim; ++a) {
        // 128-bit so that p * a cannot overflow before the p <= N test fails.
        for (i128 p = (i128)a * a * a; p <= N; p *= a) {
            events.emplace_back((long long)p, (int)a);
        }
    }
    sort(events.begin(), events.end());

    // Modular inverses of 1..lim via inv[i] = -(MOD / i) * inv[MOD % i] (MOD is prime).
    vector<long long> inv(max(2LL, lim + 1), 1);
    for (long long i = 2; i <= lim; ++i) {
        inv[i] = MOD - (long long)((__int128)(MOD / i) * inv[MOD % i] % MOD);
    }

    long long ans = 0;
    long long curL = 1;
    long long R = 1;      // prod_{k>=3} floor(i^{1/k}) on the current stretch
    long long prevF = 0;  // calcF(curL - 1), carried forward so each prefix is evaluated once

    // Adds R * (F(hi) - F(curL - 1)) to ans and advances prevF to F(hi).
    auto add_stretch = [&](long long hi) {
        const long long fhi = calcF(hi);
        long long seg = (fhi - prevF) % MOD;
        if (seg < 0) seg += MOD;
        ans = (long long)((ans + (__int128)R * seg) % MOD);
        prevF = fhi;
    };

    const size_t m = events.size();
    size_t idx = 0;
    while (idx < m) {
        const long long x = events[idx].first;
        // R is constant on [curL, x - 1].
        if (curL <= x - 1) add_stretch(x - 1);
        // Apply every event at x: that k's factor goes from a - 1 to a,
        // so R is multiplied by a / (a - 1).
        for (; idx < m && events[idx].first == x; ++idx) {
            const int a = events[idx].second;
            const long long mul = (long long)((__int128)a * inv[a - 1] % MOD);
            R = (long long)((__int128)R * mul % MOD);
        }
        curL = x;
    }
    // Final stretch [curL, N].
    if (curL <= N) add_stretch(N);

    cout << ans << '\n';  // ans is kept reduced into [0, MOD)
    return 0;
}