結果

問題 No.3505 Sum of Prod of Root
コンテスト
ユーザー rumblycascade7
提出日時 2026-04-18 05:30:35
言語 Python3
(3.14.3 + numpy 2.4.4 + scipy 1.17.1)
コンパイル:
python3 -mpy_compile _filename_
実行:
python3 _filename_
結果
TLE  
実行時間 -
コード長 2,091 bytes
記録
記録タグの例:
初AC ショートコード 純ショートコード 純主流ショートコード 最速実行時間
コンパイル時間 412 ms
コンパイル使用メモリ 20,564 KB
実行使用メモリ 17,140 KB
最終ジャッジ日時 2026-04-18 05:30:57
合計ジャッジ時間 7,139 ms
ジャッジサーバーID
(参考情報)
judge3_0 / judge2_0
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample -- * 1
other AC * 4 TLE * 1 -- * 8
権限があれば一括ダウンロードができます

ソースコード

diff #
raw source code

import sys
import math

input = sys.stdin.readline
MOD = 998244353

n = int(input())

inv2 = (MOD + 1) // 2
inv6 = pow(6, MOD - 2, MOD)
inv30 = pow(30, MOD - 2, MOD)

def s1(x):
    return (x % MOD) * ((x + 1) % MOD) % MOD * inv2 % MOD

def s2(x):
    return (x % MOD) * ((x + 1) % MOD) % MOD * ((2 * x + 1) % MOD) % MOD * inv6 % MOD

def s3(x):
    t = s1(x)
    return t * t % MOD

def s4(x):
    a = x % MOD
    b = (x + 1) % MOD
    c = (2 * x + 1) % MOD
    d = (3 * a * a + 3 * a - 1) % MOD
    return a * b % MOD * c % MOD * d % MOD * inv30 % MOD

def pref(m):
    if m <= 0:
        return 0

    rt = math.isqrt(m)
    t = rt - 1

    out = (2 * s4(t) + 3 * s3(t) + s2(t)) % MOD

    l = rt * rt
    cnt = m - l + 1
    seg = ((l + m) % MOD) * (cnt % MOD) % MOD * inv2 % MOD

    out += (rt % MOD) * seg
    return out % MOD

max_a = int(round(n ** 0.25)) + 5
while max_a ** 4 <= n:
    max_a += 1
while (max_a - 1) ** 4 > n:
    max_a -= 1

inv = [0] * (max_a + 3)
inv[1] = 1
for i in range(2, max_a + 3):
    inv[i] = MOD - (MOD // i) * inv[MOD % i] % MOD

ev = {}

for k in range(4, 61):
    a = 2
    while True:
        v = a ** k
        if v > n:
            break

        cur = ev.get(v, 1)
        cur = cur * a % MOD
        cur = cur * inv[a - 1] % MOD
        ev[v] = cur

        a += 1

pts = sorted(ev.items())
m = len(pts)
ptr = 0

cube = int(round(n ** (1.0 / 3.0))) + 5
while cube ** 3 <= n:
    cube += 1
while (cube - 1) ** 3 > n:
    cube -= 1
cube -= 1

ans = 0
more = 1

for a in range(1, cube + 1):
    l = a * a * a
    if a == cube:
        r = n
    else:
        b = a + 1
        r = b * b * b - 1

    cur = l
    base = a % MOD

    while ptr < m and pts[ptr][0] <= r:
        x, mul = pts[ptr]

        if cur <= x - 1:
            part = (pref(x - 1) - pref(cur - 1)) % MOD
            ans += base * more % MOD * part
            ans %= MOD

        more = more * mul % MOD
        cur = x
        ptr += 1

    if cur <= r:
        part = (pref(r) - pref(cur - 1)) % MOD
        ans += base * more % MOD * part
        ans %= MOD

print(ans % MOD)
0