結果

問題 No.3505 Sum of Prod of Root
コンテスト
ユーザー tassei903
提出日時 2026-04-18 00:47:10
言語 PyPy3
(7.3.17)
コンパイル:
pypy3 -mpy_compile _filename_
実行:
pypy3 _filename_
結果
AC  
実行時間 1,751 ms / 2,000 ms
コード長 2,691 bytes
記録
記録タグの例:
初AC ショートコード 純ショートコード 純主流ショートコード 最速実行時間
コンパイル時間 186 ms
コンパイル使用メモリ 85,888 KB
実行使用メモリ 293,496 KB
最終ジャッジ日時 2026-04-18 00:47:20
合計ジャッジ時間 7,343 ms
ジャッジサーバーID
(参考情報)
judge1_0 / judge2_0
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 1
other AC * 13
権限があれば一括ダウンロードができます

ソースコード

diff #
raw source code

import sys
ni = lambda :int(input())
na = lambda :list(map(int,input().split()))
yes = lambda :print("yes");Yes = lambda :print("Yes");YES = lambda : print("YES")
no = lambda :print("no");No = lambda :print("No");NO = lambda : print("NO")
#####################################################################
"""

j ** k <= i < (j + 1) ** k

sum_i

10 ** 18 < 2 ** 64

1 2 3 4 5 6 7 8 9
1 1 1 2 2 2 2 2 3
1 1 1 1 1 1 1 2 2

O(n ** 0.5) は各区間に対して,


sum_{i <= r} i * floor(i ** 0.5) を O (1) で求められればいい

j * (j ** 2 + ... (j + 1) ** 2 - 1)

"""
mod = 998244353

def kth_root_integer(a: int, k: int) -> int:
    if a <= 1 or k == 1: return a
    if 64 <= k: return 1
    def check(n: int) -> bool:
        x = 1; m = n
        p = k
        while p:
            if p & 1: x *= m
            p >>= 1
            m *= m
        return x <= a
    n = int(pow(a, 1 / k))
    while not check(n): n -= 1
    while check(n + 1): n += 1
    return n

def f(n):
    res = 1
    for k in range(1, 64):
        x = kth_root_integer(n, k)
        if x == 1:
            break
        res *= x
    return res

def f_naive(r):
    ans = 0
    for i in range(1, r + 1):
        ans += i * kth_root_integer(i, 2) % mod
        ans %= mod
    return ans
m20 = pow(20, mod-2, mod)
def g(i):# f(1) + ... f(i)
    i2 = i * i % mod
    return (8 * i2 % mod * i2 % mod * i % mod + 35 * i2 % mod * i2 % mod + 50 * i2 % mod * i % mod + 25 * i2 + 2 * i) * m20 % mod

m2 = pow(2, mod-2, mod)
def f_solve(r):
    # j ** 2 <= i < (j + 1) ** 2
    J = kth_root_integer(r, 2)
    # J ** 2 <= r
    # f(j) = 2 * j ** 4 + 3 * j ** 3 + j ** 2
    # J * (J ** 2 + ... + r)
    # print("!", J, g(J-1), J * (J ** 2 + r) * (r - J ** 2 + 1) //2 )
    x = J % mod
    x2 = x * x % mod
    return (g(J - 1) + x * (x2 + r) % mod * (r - x2 + 1) % mod * m2 % mod ) % mod


from collections import defaultdict

def solve(n):
    d = defaultdict(list)
    for k in range(3, 65):
        i = 1
        while i ** k <= n:
            d[i ** k].append((k, i))
            i += 1
    
    b = [0] * 65
    last = 1
    ans = 0
    for x in sorted(d):
        z = 1
        for j in range(3, 64):
            z *= b[j]
            z %= mod
        ans += (f_solve(x-1) - f_solve(last-1)) * z % mod
        ans %= mod
        for k, i in d[x]:
            b[k] = i
        last = x
    z = 1
    for j in range(3, 64):
        z *= b[j]
        z %= mod
    ans += (f_solve(n) - f_solve(last-1)) * z
    ans %= mod
    return ans

def naive(n):
    ans = 0
    for i in range(1, n + 1):
        ans += f(i)
    return ans

n = ni()
print(solve(n))


# for i in range(1, 30):
#     print(i, naive(i), solve(i))
0