結果
| 問題 | No.3505 Sum of Prod of Root |
| コンテスト | |
| ユーザー | |
| 提出日時 | 2026-04-18 18:06:04 |
| 言語 | C++23 (gcc 15.2.0 + boost 1.89.0) |
| 結果 | AC |
| 実行時間 | 195 ms / 2,000 ms |
| コード長 | 4,602 bytes |
| 記録 | |
| コンパイル時間 | 3,263 ms |
| コンパイル使用メモリ | 344,012 KB |
| 実行使用メモリ | 27,392 KB |
| 最終ジャッジ日時 | 2026-04-18 18:06:09 |
| 合計ジャッジ時間 | 4,683 ms |
| ジャッジサーバーID (参考情報) | judge3_0 / judge1_0 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 1 |
| other | AC * 13 |
ソースコード
#include <bits/stdc++.h>
using namespace std;
// Short aliases for the integer widths used throughout the file.
using int64 = long long;
using u64 = unsigned long long;
using u128 = __uint128_t;  // GCC/Clang extension; used to form 64x64-bit products without overflow
// Modulus for all arithmetic. Inverses below are taken as x^(MOD-2), which relies on MOD being prime.
static const int64 MOD = 998244353;
// Returns a^e mod MOD by binary exponentiation; e <= 0 yields 1 % MOD.
// The base is first reduced into [0, MOD), so negative bases give the mathematically correct result.
int64 mod_pow(int64 a, int64 e) {
    int64 r = 1 % MOD;
    a %= MOD;
    // C++ '%' keeps the sign of the dividend. A negative base would otherwise be reinterpreted
    // as a huge unsigned value by the u128 casts below.
    if (a < 0) a += MOD;
    while (e > 0) {
        if (e & 1) r = (int64)((u128)r * a % MOD);
        a = (int64)((u128)a * a % MOD);
        e >>= 1;
    }
    return r;
}
// Returns (a * b) % MOD. The product is formed in 128 bits, so it cannot overflow for
// non-negative int64 operands (callers here pass values already reduced mod MOD).
inline int64 mul_mod(int64 a, int64 b) {
    const u128 product = static_cast<u128>(a) * b;
    return static_cast<int64>(product % MOD);
}
// Returns true iff a^e <= lim, without overflow: the running power is compared
// against lim / a before each multiplication, so it never exceeds lim.
bool power_leq(u64 a, int e, u64 lim) {
    if (e <= 0) return lim >= 1;  // a^0 == 1
    if (a == 0) return true;      // 0^e == 0 <= lim for e >= 1 (and avoids lim / 0 below)
    u128 v = 1;
    for (int i = 0; i < e; i++) {
        // v * a <= lim  <=>  v <= floor(lim / a)
        if (v > (u128)lim / a) return false;
        v *= a;
    }
    return true;
}
// Returns a^e by repeated multiplication (1 when e <= 0).
// The caller must guarantee the true result fits in u64 (here it is always <= 1e18).
u64 power_exact(u64 a, int e) {
    u128 acc = 1;
    for (int remaining = e; remaining > 0; --remaining) {
        acc *= a;
    }
    return static_cast<u64>(acc);
}
// Returns floor(n^(1/k)) for any u64 n and k >= 1, by binary search on the answer.
// The upper end of the search is 2^ceil(64/k). Since n < 2^64, the true root is < 2^(64/k),
// so the answer is always inside the range. (A fixed 1e6 cap is only valid for n <= 1e18, k >= 3.)
u64 kth_root_floor(u64 n, int k) {
    if (n == 0) return 0;   // 0^(1/k) == 0
    if (k <= 1) return n;   // k == 1 is the identity; k < 1 is not meaningful here
    const int bits = (64 + k - 1) / k;  // ceil(64 / k), at most 32 for k >= 2
    u64 lo = 1;                         // invariant: lo^k <= n (1 <= n holds)
    u64 hi = min<u64>(n, 1ULL << bits);
    while (lo < hi) {
        const u64 mid = lo + (hi - lo + 1) / 2;  // round up so the loop always shrinks
        if (power_leq(mid, k, n)) lo = mid;
        else hi = mid - 1;
    }
    return lo;
}
// Returns floor(sqrt(x)). It starts from the long-double estimate and then corrects
// any rounding error in exact 128-bit integer arithmetic.
u64 isqrt_u64(u64 x) {
    u64 root = static_cast<u64>(sqrtl(static_cast<long double>(x)));
    for (; static_cast<u128>(root + 1) * (root + 1) <= x; ++root) {
    }
    for (; static_cast<u128>(root) * root > x; --root) {
    }
    return root;
}
int main() {
ios::sync_with_stdio(false);
cin.tie(nullptr);
u64 N;
cin >> N;
// Gather events: (x = m^k, multiplier changes by m/(m-1)), k>=3, m>=2
vector<pair<u64, int>> events;
events.reserve(1100000);
for (int k = 3; k <= 60; k++) {
if (!power_leq(2, k, N)) break; // no m>=2 possible afterwards
u64 lim = kth_root_floor(N, k);
for (u64 m = 2; m <= lim; m++) {
u64 x = power_exact(m, k);
events.push_back({x, (int)m});
}
}
sort(events.begin(), events.end(),
[](const auto& a, const auto& b) {
if (a.first != b.first) return a.first < b.first;
return a.second < b.second;
});
// inverses up to max m
int max_m = 1;
for (auto &e : events) max_m = max(max_m, e.second);
vector<int64> inv(max_m + 1, 0);
inv[1] = 1;
for (int i = 2; i <= max_m; i++) {
inv[i] = (MOD - mul_mod(MOD / i, inv[MOD % i])) % MOD;
}
const int64 INV2 = mod_pow(2, MOD - 2);
const int64 INV6 = mod_pow(6, MOD - 2);
const int64 INV30 = mod_pow(30, MOD - 2);
auto sum1 = [&](u64 n) -> int64 {
int64 a = (int64)(n % MOD);
int64 b = (int64)((n + 1) % MOD);
return mul_mod(mul_mod(a, b), INV2);
};
auto sum2 = [&](u64 n) -> int64 {
int64 a = (int64)(n % MOD);
int64 b = (int64)((n + 1) % MOD);
int64 c = (int64)((2 * (n % MOD) + 1) % MOD);
return mul_mod(mul_mod(mul_mod(a, b), c), INV6);
};
auto sum3 = [&](u64 n) -> int64 {
int64 s = sum1(n);
return mul_mod(s, s);
};
auto sum4 = [&](u64 n) -> int64 {
int64 a = (int64)(n % MOD);
int64 b = (int64)((n + 1) % MOD);
int64 c = (int64)((2 * (n % MOD) + 1) % MOD);
int64 n2 = mul_mod(a, a);
int64 d = ( (3 * n2) % MOD + (3 * a) % MOD - 1 + MOD ) % MOD;
return mul_mod(mul_mod(mul_mod(mul_mod(a, b), c), d), INV30);
};
auto G = [&](u64 x) -> int64 {
if (x == 0) return 0;
u64 s = isqrt_u64(x);
u64 u = s - 1;
int64 part1 = 0;
part1 = (part1 + 2 * sum4(u)) % MOD;
part1 = (part1 + 3 * sum3(u)) % MOD;
part1 = (part1 + sum2(u)) % MOD;
u64 sq = (u64)((u128)s * s);
int64 tail = (sum1(x) - sum1(sq - 1) + MOD) % MOD;
int64 part2 = mul_mod((int64)(s % MOD), tail);
return (part1 + part2) % MOD;
};
auto range_sum = [&](u64 L, u64 R) -> int64 {
if (L > R) return 0;
return (G(R) - G(L - 1) + MOD) % MOD;
};
int64 ans = 0;
int64 T = 1; // product_{k>=3} floor(i^(1/k)) on current interval
u64 prev = 1;
size_t i = 0;
while (i < events.size()) {
u64 x = events[i].first;
if (prev <= x - 1) {
int64 seg = range_sum(prev, x - 1);
ans = (ans + mul_mod(T, seg)) % MOD;
}
// apply all events at x
while (i < events.size() && events[i].first == x) {
int m = events[i].second;
T = mul_mod(T, m);
T = mul_mod(T, inv[m - 1]);
i++;
}
prev = x;
}
if (prev <= N) {
int64 seg = range_sum(prev, N);
ans = (ans + mul_mod(T, seg)) % MOD;
}
cout << ans % MOD << '\n';
return 0;
}