結果

問題 No.3505 Sum of Prod of Root
コンテスト
ユーザー ガルム
提出日時 2026-04-18 14:10:18
言語 Rust
(1.94.0 + proconio + num + itertools)
コンパイル:
/usr/bin/rustc_custom
実行:
./target/release/main
結果
AC  
実行時間 157 ms / 2,000 ms
コード長 3,555 bytes
記録
記録タグの例:
初AC ショートコード 純ショートコード 純主流ショートコード 最速実行時間
コンパイル時間 1,492 ms
コンパイル使用メモリ 201,844 KB
実行使用メモリ 26,240 KB
最終ジャッジ日時 2026-04-18 14:10:39
合計ジャッジ時間 4,063 ms
ジャッジサーバーID
(参考情報)
judge1_1 / judge2_0
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 1
other AC * 13
権限があれば一括ダウンロードができます

ソースコード

diff #
raw source code

use std::io::{self, Read};

const MOD: u64 = 998244353;

fn mul(a: u64, b: u64) -> u64 {
    ((a as u128 * b as u128) % MOD as u128) as u64
}

fn mpow(mut a: u64, mut e: u64) -> u64 {
    let mut r = 1u64;
    while e > 0 {
        if e & 1 == 1 {
            r = mul(r, a);
        }
        a = mul(a, a);
        e >>= 1;
    }
    r
}

fn pow_lim(a: u64, k: u32, lim: u64) -> Option<u64> {
    let mut r = 1u128;
    for _ in 0..k {
        r *= a as u128;
        if r > lim as u128 {
            return None;
        }
    }
    Some(r as u64)
}

fn kth_root(n: u64, k: u32) -> u64 {
    let mut ok = 1u64;
    let mut ng = 2u64;
    while pow_lim(ng, k, n).is_some() {
        ok = ng;
        ng *= 2;
    }
    while ng - ok > 1 {
        let mid = (ok + ng) / 2;
        if pow_lim(mid, k, n).is_some() {
            ok = mid;
        } else {
            ng = mid;
        }
    }
    ok
}

fn isqrt(n: u64) -> u64 {
    let mut x = (n as f64).sqrt() as u64;
    while (x + 1) as u128 * (x + 1) as u128 <= n as u128 {
        x += 1;
    }
    while x as u128 * x as u128 > n as u128 {
        x -= 1;
    }
    x
}

fn sum2(t: u64, inv6: u64) -> u64 {
    let a = t % MOD;
    let b = (t + 1) % MOD;
    let c = ((2 * (t % MOD)) + 1) % MOD;
    mul(mul(mul(a, b), c), inv6)
}

fn sum3(t: u64, inv2: u64) -> u64 {
    let a = t % MOD;
    let b = (t + 1) % MOD;
    let x = mul(mul(a, b), inv2);
    mul(x, x)
}

fn sum4(t: u64, inv30: u64) -> u64 {
    let a = t % MOD;
    let b = (t + 1) % MOD;
    let c = ((2 * (t % MOD)) + 1) % MOD;
    let aa = mul(a, a);
    let d = (3 * aa % MOD + 3 * a % MOD + MOD - 1) % MOD;
    mul(mul(mul(mul(a, b), c), d), inv30)
}

fn sum_i_sqrt(n: u64, inv2: u64, inv6: u64, inv30: u64) -> u64 {
    if n == 0 {
        return 0;
    }

    let m = isqrt(n);
    let t = m - 1;
    let full = (2 * sum4(t, inv30) % MOD + 3 * sum3(t, inv2) % MOD + sum2(t, inv6)) % MOD;

    let l = m * m;
    let cnt = n - l + 1;
    let s = mul(mul((l % MOD + n % MOD) % MOD, cnt % MOD), inv2);
    (full + mul(m % MOD, s)) % MOD
}

fn range_sum(l: u64, r: u64, inv2: u64, inv6: u64, inv30: u64) -> u64 {
    (sum_i_sqrt(r, inv2, inv6, inv30) + MOD - sum_i_sqrt(l - 1, inv2, inv6, inv30)) % MOD
}

fn main() {
    let mut s = String::new();
    io::stdin().read_to_string(&mut s).unwrap();
    let n: u64 = s.trim().parse().unwrap();

    let inv2 = mpow(2, MOD - 2);
    let inv6 = mpow(6, MOD - 2);
    let inv30 = mpow(30, MOD - 2);

    let mx = kth_root(n, 3) as usize + 2;
    let mut inv = vec![0u64; mx.max(2) + 1];
    inv[1] = 1;
    for i in 2..inv.len() {
        inv[i] = MOD - mul(MOD / i as u64, inv[(MOD % i as u64) as usize]);
    }

    let mut ev = Vec::<(u64, usize)>::new();
    let mut k = 3u32;
    while pow_lim(2, k, n).is_some() {
        let mut x = 2u64;
        while let Some(p) = pow_lim(x, k, n) {
            ev.push((p, x as usize));
            x += 1;
        }
        k += 1;
    }
    ev.sort_unstable();

    let mut ans = 0u64;
    let mut cur = 1u64;
    let mut prod = 1u64;
    let mut i = 0usize;

    while i < ev.len() {
        let p = ev[i].0;
        if cur < p {
            ans = (ans + mul(prod, range_sum(cur, p - 1, inv2, inv6, inv30))) % MOD;
        }

        while i < ev.len() && ev[i].0 == p {
            let x = ev[i].1;
            prod = mul(mul(prod, x as u64), inv[x - 1]);
            i += 1;
        }
        cur = p;
    }

    if cur <= n {
        ans = (ans + mul(prod, range_sum(cur, n, inv2, inv6, inv30))) % MOD;
    }

    println!("{}", ans);
}
0