結果
| 問題 | No.3505 Sum of Prod of Root |
| コンテスト | |
| ユーザー |
|
| 提出日時 | 2026-04-18 00:47:10 |
| 言語 | PyPy3 (7.3.17) |
| 結果 |
AC
|
| 実行時間 | 1,751 ms / 2,000 ms |
| コード長 | 2,691 bytes |
| 記録 | |
| コンパイル時間 | 186 ms |
| コンパイル使用メモリ | 85,888 KB |
| 実行使用メモリ | 293,496 KB |
| 最終ジャッジ日時 | 2026-04-18 00:47:20 |
| 合計ジャッジ時間 | 7,343 ms |
|
ジャッジサーバーID (参考情報) |
judge1_0 / judge2_0 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 1 |
| other | AC * 13 |
ソースコード
import sys
ni = lambda :int(input())
na = lambda :list(map(int,input().split()))
yes = lambda :print("yes");Yes = lambda :print("Yes");YES = lambda : print("YES")
no = lambda :print("no");No = lambda :print("No");NO = lambda : print("NO")
#####################################################################
"""
j ** k <= i < (j + 1) ** k
sum_i
10 ** 18 < 2 ** 64
1 2 3 4 5 6 7 8 9
1 1 1 2 2 2 2 2 3
1 1 1 1 1 1 1 2 2
O(n ** 0.5) は各区間に対して,
sum_{i <= r} i * floor(i ** 0.5) を O (1) で求められればいい
j * (j ** 2 + ... (j + 1) ** 2 - 1)
"""
mod = 998244353
def kth_root_integer(a: int, k: int) -> int:
if a <= 1 or k == 1: return a
if 64 <= k: return 1
def check(n: int) -> bool:
x = 1; m = n
p = k
while p:
if p & 1: x *= m
p >>= 1
m *= m
return x <= a
n = int(pow(a, 1 / k))
while not check(n): n -= 1
while check(n + 1): n += 1
return n
def f(n):
res = 1
for k in range(1, 64):
x = kth_root_integer(n, k)
if x == 1:
break
res *= x
return res
def f_naive(r):
ans = 0
for i in range(1, r + 1):
ans += i * kth_root_integer(i, 2) % mod
ans %= mod
return ans
m20 = pow(20, mod-2, mod)
def g(i):# f(1) + ... f(i)
i2 = i * i % mod
return (8 * i2 % mod * i2 % mod * i % mod + 35 * i2 % mod * i2 % mod + 50 * i2 % mod * i % mod + 25 * i2 + 2 * i) * m20 % mod
m2 = pow(2, mod-2, mod)
def f_solve(r):
# j ** 2 <= i < (j + 1) ** 2
J = kth_root_integer(r, 2)
# J ** 2 <= r
# f(j) = 2 * j ** 4 + 3 * j ** 3 + j ** 2
# J * (J ** 2 + ... + r)
# print("!", J, g(J-1), J * (J ** 2 + r) * (r - J ** 2 + 1) //2 )
x = J % mod
x2 = x * x % mod
return (g(J - 1) + x * (x2 + r) % mod * (r - x2 + 1) % mod * m2 % mod ) % mod
from collections import defaultdict
def solve(n):
d = defaultdict(list)
for k in range(3, 65):
i = 1
while i ** k <= n:
d[i ** k].append((k, i))
i += 1
b = [0] * 65
last = 1
ans = 0
for x in sorted(d):
z = 1
for j in range(3, 64):
z *= b[j]
z %= mod
ans += (f_solve(x-1) - f_solve(last-1)) * z % mod
ans %= mod
for k, i in d[x]:
b[k] = i
last = x
z = 1
for j in range(3, 64):
z *= b[j]
z %= mod
ans += (f_solve(n) - f_solve(last-1)) * z
ans %= mod
return ans
def naive(n):
ans = 0
for i in range(1, n + 1):
ans += f(i)
return ans
n = ni()
print(solve(n))
# for i in range(1, 30):
# print(i, naive(i), solve(i))