結果
| 問題 | No.3505 Sum of Prod of Root |
| コンテスト | |
| ユーザー | |
| 提出日時 | 2026-04-18 01:50:34 |
| 言語 | C++23 (gcc 15.2.0 + boost 1.89.0) |
| 結果 | AC |
| 実行時間 | 223 ms / 2,000 ms |
| コード長 | 3,454 bytes |
| 記録 | |
| コンパイル時間 | 763 ms |
| コンパイル使用メモリ | 76,288 KB |
| 実行使用メモリ | 35,840 KB |
| 最終ジャッジ日時 | 2026-04-18 01:50:56 |
| 合計ジャッジ時間 | 2,319 ms |
| ジャッジサーバーID (参考情報) | judge1_1 / judge2_1 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 1 |
| other | AC * 13 |
ソースコード
#include <stdio.h>
#include <stdlib.h>
#include <stdint.h>
#include <math.h>
/* Shorthand integer types. i128 (a GCC/Clang extension) is used so that the
   product of two values below MOD can be formed without overflow. */
typedef long long ll;
typedef __int128 i128;
/* Modulus for all arithmetic. The linear inverse recurrence in main
   (inv[i] = -(MOD / i) * inv[MOD % i]) requires it to be prime; 998244353 is. */
#define MOD 998244353LL
/* Size of inv[]: main fills inv[1..MAXA-1]. Events index inv[a - 1] for
   bases a up to floor(cbrt(N)), so MAXA - 1 must be at least that.
   NOTE(review): presumably sized for N <= 1e18 -- confirm against the problem
   constraints. */
#define MAXA 1000005
/* Capacity of ev[]: one entry per pair (a, k) with a >= 2, k >= 3, a^k <= N
   (roughly 1.04e6 entries for N = 1e18). */
#define MAXE 1100000
/* A point where the running product in main changes. x = a^k; at i = x the
   factor floor(i^(1/k)) steps from a - 1 to a, so the product is multiplied
   by mul = a * inv(a - 1) mod MOD. */
typedef struct {
ll x;
ll mul;
} Event;
/* Event list; sorted by x (qsort with cmp_event) before the sweep in main. */
static Event ev[MAXE];
/* Modular inverses of 1..MAXA-1, filled at the start of main. */
static ll inv[MAXA];
/* Number of valid entries in ev[]. */
static int ecnt = 0;
/* Reduce x into the canonical range [0, MOD). C's % truncates toward zero,
   so a negative x gives a negative remainder that must be shifted up. */
static inline ll norm(ll x) {
    const ll r = x % MOD;
    return (r < 0) ? r + MOD : r;
}
/*
 * Modular exponentiation by repeated squaring.
 *
 * a: base, any value (including negative); it is reduced into [0, MOD)
 *    first, so the result is always a canonical residue.
 * e: exponent; e <= 0 returns 1 (negative exponents are NOT treated as
 *    inverses).
 * Returns a^e mod MOD in [0, MOD).
 *
 * NOTE(review): this helper has no caller in the visible source.
 */
static ll mod_pow(ll a, ll e) {
    /* Without this, a negative base propagates a negative remainder into r. */
    ll base = a % MOD;
    if (base < 0) base += MOD;
    ll r = 1;
    while (e > 0) {
        if (e & 1) r = (ll)((i128)r * base % MOD);
        base = (ll)((i128)base * base % MOD);
        e >>= 1;
    }
    return r;
}
static int cmp_event(const void *pa, const void *pb) {
ll a = ((const Event *)pa)->x;
ll b = ((const Event *)pb)->x;
return (a < b) ? -1 : (a > b);
}
/* 1 + 2 + ... + n = n(n+1)/2 (mod MOD); the halving uses (MOD+1)/2,
   the modular inverse of 2. */
static inline ll sum1(ll n) {
    const ll half = (MOD + 1) / 2;
    const ll r = n % MOD;
    const i128 prod = (i128)r * ((r + 1) % MOD) % MOD;
    return (ll)(prod * half % MOD);
}
/* 1^2 + 2^2 + ... + n^2 = n(n+1)(2n+1)/6 (mod MOD). */
static inline ll sum2(ll n) {
    static const ll INV6 = 166374059LL; /* 6 * INV6 == 1 (mod MOD) */
    const ll r = n % MOD;
    i128 acc = (i128)r * ((n + 1) % MOD) % MOD;
    acc = acc * ((2 * r + 1) % MOD) % MOD;
    return (ll)(acc * INV6 % MOD);
}
/* 1^3 + 2^3 + ... + n^3 = (n(n+1))^2 / 4 (mod MOD). */
static inline ll sum3(ll n) {
    static const ll INV4 = 748683265LL; /* 4 * INV4 == 1 (mod MOD) */
    const i128 base = (i128)(n % MOD) * ((n + 1) % MOD) % MOD;
    return (ll)(base * base % MOD * INV4 % MOD);
}
/* 1^4 + 2^4 + ... + n^4 = n(n+1)(2n+1)(3n^2+3n-1) / 30 (mod MOD). */
static inline ll sum4(ll n) {
    static const ll INV30 = 432572553LL; /* 30 * INV30 == 1 (mod MOD) */
    const ll r = n % MOD;
    /* 3n^2 + 3n - 1, brought back into [0, MOD) */
    const ll quad = norm((ll)((3 * (i128)r * r) % MOD) + 3 * r - 1);
    i128 acc = (i128)r * ((n + 1) % MOD) % MOD;
    acc = acc * ((2 * r + 1) % MOD) % MOD;
    acc = acc * quad % MOD;
    return (ll)(acc * INV30 % MOD);
}
/*
 * Exact floor(sqrt(n)) for a 64-bit n.
 *
 * sqrtl supplies an estimate; the two loops then correct any rounding error
 * in either direction. They use 128-bit products, so (r+1)^2 cannot overflow.
 * Non-positive n returns 0: sqrtl of a negative value is NaN, and converting
 * NaN to an integer type is undefined behaviour.
 */
static inline ll isqrt_ll(ll n) {
    if (n <= 0) return 0;
    ll r = (ll)sqrtl((long double)n);
    while ((i128)(r + 1) * (r + 1) <= n) r++;
    while ((i128)r * r > n) r--;
    return r;
}
/*
 * pref(n) = sum_{i=1}^{n} i * floor(sqrt(i))  (mod MOD); 0 for n <= 0.
 *
 * With m = floor(sqrt(n)), the range splits in two parts:
 *  - complete bands [j^2, (j+1)^2 - 1] for j = 1..m-1, on which
 *    floor(sqrt(i)) = j. Band j holds 2j+1 terms averaging j^2 + j, so it
 *    contributes j * (2j+1)(j^2+j) = 2j^4 + 3j^3 + j^2. All complete bands
 *    together give 2*S4(m-1) + 3*S3(m-1) + S2(m-1).
 *  - the partial band [m^2, n], where floor(sqrt(i)) = m, contributing
 *    m * (S1(n) - S1(m^2 - 1)).
 */
static ll pref(ll n) {
    if (n <= 0) return 0;
    const ll m = isqrt_ll(n);
    const ll last_band = m - 1;
    ll bands = 0;
    if (last_band >= 1) {
        bands = norm(2 * sum4(last_band) + 3 * sum3(last_band) + sum2(last_band));
    }
    const ll before_tail = (ll)((i128)m * m - 1);
    ll partial = norm(sum1(n) - sum1(before_tail));
    partial = (ll)((i128)partial * (m % MOD) % MOD);
    return norm(bands + partial);
}
/*
 * Append one event to ev[] for every perfect power x = a^k <= N with a >= 2
 * and k >= 3. The event carries mul = a * inv(a - 1) mod MOD.
 * Requires inv[] to be filled already.
 * Returns 0 on success, or -1 if inv[] or ev[] would be indexed out of
 * bounds (N too large for the fixed-size tables).
 */
static int build_events(ll N) {
    const ll inv_cap = (ll)(sizeof inv / sizeof inv[0]);
    const int ev_cap = (int)(sizeof ev / sizeof ev[0]);
    /* An exponent k can only produce an event while 2^k <= N. Deriving the
       bound from N (instead of a fixed cap) keeps every exponent that fires. */
    for (int k = 3; ((i128)1 << k) <= N; k++) {
        for (ll a = 2;; a++) {
            i128 p = 1;
            for (int t = 0; t < k; t++) {
                p *= a;
                if (p > N) break;
            }
            if (p > N) break;
            if (a - 1 >= inv_cap || ecnt >= ev_cap) return -1;
            ev[ecnt].x = (ll)p;
            ev[ecnt].mul = (ll)((i128)(a % MOD) * inv[a - 1] % MOD);
            ecnt++;
        }
    }
    return 0;
}

/*
 * Return sum_{i=1}^{N} H(i) * i * floor(sqrt(i)) mod MOD.
 * H(i) is the product of mul over all events with x <= i, and 1 before the
 * first event. For a fixed k, the factors a/(a-1) for a = 2..floor(i^(1/k))
 * telescope to floor(i^(1/k)), so H(i) = prod_{k>=3} floor(i^(1/k)) mod MOD.
 * Requires ev[0..ecnt) sorted by x.
 */
static ll sweep_events(ll N) {
    ll ans = 0;
    ll H = 1;
    ll prev = 1; /* first i not yet accumulated */
    int i = 0;
    while (i < ecnt) {
        const ll x = ev[i].x;
        /* H is constant on [prev, x - 1] */
        if (prev <= x - 1) {
            ll seg = norm(pref(x - 1) - pref(prev - 1));
            ans = norm(ans + (ll)((i128)H * seg % MOD));
        }
        /* Several (a, k) pairs can share one x, e.g. 64 = 4^3 = 2^6. */
        while (i < ecnt && ev[i].x == x) {
            H = (ll)((i128)H * ev[i].mul % MOD);
            i++;
        }
        prev = x;
    }
    if (prev <= N) {
        ll seg = norm(pref(N) - pref(prev - 1));
        ans = norm(ans + (ll)((i128)H * seg % MOD));
    }
    return ans;
}

/*
 * Read N from stdin and print sum_{i=1}^{N} i * prod_{k>=2} floor(i^(1/k))
 * mod MOD. The k = 2 factor comes from pref(); the k >= 3 factors come from
 * the event sweep. Returns a non-zero exit status if N cannot be read or
 * does not fit the precomputed tables.
 */
int main(void) {
    ll N;
    if (scanf("%lld", &N) != 1) {
        fprintf(stderr, "failed to read N\n");
        return 1;
    }
    /* Linear-time modular inverses: inv[i] = -(MOD / i) * inv[MOD % i]
       (mod MOD), valid because MOD is prime. */
    inv[1] = 1;
    for (int i = 2; i < MAXA; i++) {
        inv[i] = MOD - (ll)((i128)(MOD / i) * inv[MOD % i] % MOD);
    }
    if (build_events(N) != 0) {
        fprintf(stderr, "N is too large for the precomputed tables\n");
        return 1;
    }
    qsort(ev, (size_t)ecnt, sizeof ev[0], cmp_event);
    printf("%lld\n", sweep_events(N));
    return 0;
}