結果
| 問題 | No.3505 Sum of Prod of Root |
| コンテスト | |
| ユーザー | |
| 提出日時 | 2026-04-17 21:56:19 |
| 言語 | C++23 (gcc 15.2.0 + boost 1.89.0) |
| 結果 | AC |
| 実行時間 | 272 ms / 2,000 ms |
| コード長 | 3,903 bytes |
| 記録 | |
| コンパイル時間 | 1,967 ms |
| コンパイル使用メモリ | 193,848 KB |
| 実行使用メモリ | 43,900 KB |
| 最終ジャッジ日時 | 2026-04-17 21:56:28 |
| 合計ジャッジ時間 | 4,136 ms |
| ジャッジサーバーID (参考情報) | judge2_1 / judge3_0 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 1 |
| other | AC * 13 |
ソースコード
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>
using namespace std;
// Signed 64-bit working integer used throughout.
using int64 = long long;
// 128-bit intermediate for products that may exceed 64 bits.
using i128 = __int128_t;

// Prime modulus for all arithmetic.
static constexpr int64 MOD = 998244353;

// Inverses of the denominators appearing in the power-sum closed forms.
// Each is written as (c * MOD + 1) / d, where c is chosen so the division is exact.
static constexpr int64 INV2 = (MOD + 1) / 2;
static constexpr int64 INV6 = (MOD + 1) / 6;
static constexpr int64 INV30 = (13 * MOD + 1) / 30;

static_assert(INV2 * 2 % MOD == 1, "INV2 must be the inverse of 2");
static_assert(INV6 * 6 % MOD == 1, "INV6 must be the inverse of 6");
static_assert(INV30 * 30 % MOD == 1, "INV30 must be the inverse of 30");
// Returns a * b reduced modulo MOD.
// The product is formed in 128 bits, so it cannot overflow before reduction.
int64 mul_mod(int64 a, int64 b) {
    const i128 product = static_cast<i128>(a) * b;
    return static_cast<int64>(product % MOD);
}
// Adds two residues in [0, MOD) and folds the result back into [0, MOD).
int64 add_mod(int64 a, int64 b) {
    const int64 total = a + b;
    return total >= MOD ? total - MOD : total;
}
// Subtracts two residues in [0, MOD) and folds the result back into [0, MOD).
int64 sub_mod(int64 a, int64 b) {
    const int64 diff = a - b;
    return diff < 0 ? diff + MOD : diff;
}
// Returns true iff base^k <= limit (true for k <= 0).
// Exits as soon as a partial power exceeds limit. For a non-negative base, the
// 128-bit accumulator therefore never grows beyond limit * base.
bool le_pow(int64 base, int k, int64 limit) {
    i128 acc = 1;
    for (int remaining = k; remaining > 0; --remaining) {
        acc *= base;
        if (acc > limit) {
            return false;
        }
    }
    return true;
}
// Exact floor(n^(1/k)).
// For k == 1 or n <= 1, n is returned unchanged. Otherwise a long-double pow()
// gives a starting guess, clamped to at least 1. The guess is then corrected
// with the exact comparison le_pow until root^k <= n < (root + 1)^k.
int64 kth_root_floor(int64 n, int k) {
    if (n <= 1 || k == 1) {
        return n;
    }
    const long double estimate = pow(static_cast<long double>(n), 1.0L / k);
    int64 root = static_cast<int64>(estimate);
    if (root < 1) {
        root = 1;
    }
    for (; le_pow(root + 1, k, n); ++root) {
    }
    for (; !le_pow(root, k, n); --root) {
    }
    return root;
}
// Exact floor(sqrt(n)) for n >= 0.
// A long-double sqrt provides the guess; integer squares computed in 128 bits
// then fix any rounding error in either direction.
int64 isqrt_floor(int64 n) {
    auto square = [](int64 v) { return static_cast<i128>(v) * v; };
    int64 root = static_cast<int64>(sqrt(static_cast<long double>(n)));
    while (square(root + 1) <= n) {
        ++root;
    }
    while (square(root) > n) {
        --root;
    }
    return root;
}
// T(n) = 1 + 2 + ... + n = n(n+1)/2, reduced modulo MOD.
// Returns 0 for n <= 0.
// n + 1 is formed from the already-reduced n. Computing (n + 1) directly is
// signed overflow (undefined behaviour) when n == INT64_MAX, and this helper
// is called with the raw input N.
int64 tri(int64 n) {
    if (n <= 0) return 0;
    const int64 a = n % MOD;
    const int64 b = (a + 1) % MOD;
    return mul_mod(mul_mod(a, b), INV2);
}
// S2(n) = 1^2 + ... + n^2 = n(n+1)(2n+1)/6, reduced modulo MOD.
// Returns 0 for n <= 0.
// Every factor is derived from n % MOD, so no intermediate (such as n + 1)
// can overflow even for n == INT64_MAX.
int64 sum2(int64 n) {
    if (n <= 0) return 0;
    const int64 a = n % MOD;
    const int64 b = (a + 1) % MOD;
    const int64 c = (2 * a + 1) % MOD;
    return mul_mod(mul_mod(a, b), mul_mod(c, INV6));
}
// S3(n) = 1^3 + ... + n^3, computed as T(n)^2 (sum of cubes equals the square
// of the triangular number), reduced modulo MOD. Returns 0 for n <= 0.
int64 sum3(int64 n) {
    if (n > 0) {
        const int64 triangular = tri(n);
        return mul_mod(triangular, triangular);
    }
    return 0;
}
// S4(n) = 1^4 + ... + n^4 = n(n+1)(2n+1)(3n^2+3n-1)/30, reduced modulo MOD.
// Returns 0 for n <= 0.
// Every factor is derived from n % MOD, so no intermediate (such as n + 1)
// can overflow even for n == INT64_MAX.
int64 sum4(int64 n) {
    if (n <= 0) return 0;
    const int64 a = n % MOD;
    const int64 b = (a + 1) % MOD;
    const int64 c = (2 * a + 1) % MOD;
    // 3a^2 + 3a - 1: add (MOD - 1) rather than subtract 1 so the value stays
    // non-negative. The sum is below 7 * MOD, well within int64.
    const int64 d = (3 * mul_mod(a, a) + 3 * a + MOD - 1) % MOD;
    return mul_mod(mul_mod(mul_mod(a, b), c), mul_mod(d, INV30));
}
// W(x) = sum_{n=1}^{x} n * floor(sqrt(n)), reduced modulo MOD.
// Returns 0 for x <= 0.
//
// Terms are grouped by t = floor(sqrt(n)). A complete group t covers
// n in [t^2, t^2 + 2t]: 2t + 1 terms whose sum is t(t+1)(2t+1). Its
// contribution t * t(t+1)(2t+1) = 2t^4 + 3t^3 + t^2 is summed over
// t = 1 .. s-1 using the closed forms S4, S3 and S2. The final group t = s
// may be partial and covers n in [s^2, x].
int64 prefix_weight(int64 x) {
    if (x <= 0) return 0;
    const int64 s = isqrt_floor(x);
    const int64 last_full = s - 1;

    const int64 quartic = mul_mod(2, sum4(last_full));
    const int64 cubic = mul_mod(3, sum3(last_full));
    const int64 complete = add_mod(add_mod(quartic, cubic), sum2(last_full));

    // s^2 <= x, so the square fits in int64.
    const int64 first_partial = static_cast<int64>(static_cast<i128>(s) * s);
    const int64 partial_sum = sub_mod(tri(x), tri(first_partial - 1));
    return add_mod(complete, mul_mod(s % MOD, partial_sum));
}
int main() {
ios::sync_with_stdio(false);
cin.tie(nullptr);
int64 N;
cin >> N;
int64 max_base = kth_root_floor(N, 3);
vector<int64> inv(max<int64>(max_base + 1, 2), 1);
for (int64 i = 2; i <= max_base; ++i) {
inv[i] = MOD - mul_mod(MOD / i, inv[MOD % i]);
}
vector<pair<int64, int64>> events;
for (int k = 3; k <= 60; ++k) {
int64 lim = kth_root_floor(N, k);
for (int64 b = 2; b <= lim; ++b) {
i128 p = 1;
for (int i = 0; i < k; ++i) p *= b;
events.push_back({(int64)p, mul_mod(b % MOD, inv[b - 1])});
}
}
sort(events.begin(), events.end());
vector<pair<int64, int64>> merged;
merged.reserve(events.size());
for (auto [x, mul] : events) {
if (!merged.empty() && merged.back().first == x) {
merged.back().second = mul_mod(merged.back().second, mul);
} else {
merged.push_back({x, mul});
}
}
int64 ans = 0;
int64 cur = 1;
int64 h = 1; // current value of H(n) on the active interval
for (auto [x, mul] : merged) {
if (cur <= x - 1) {
int64 segment = sub_mod(prefix_weight(x - 1), prefix_weight(cur - 1));
ans = add_mod(ans, mul_mod(h, segment));
}
h = mul_mod(h, mul);
cur = x;
}
if (cur <= N) {
int64 segment = sub_mod(prefix_weight(N), prefix_weight(cur - 1));
ans = add_mod(ans, mul_mod(h, segment));
}
cout << ans << '\n';
return 0;
}