結果
| 問題 | No.3505 Sum of Prod of Root |
| コンテスト | |
| ユーザー |
ID 21712
|
| 提出日時 | 2026-04-20 20:16:56 |
| 言語 | Go (1.26.1) |
| 結果 |
RE
|
| 実行時間 | - |
| コード長 | 6,005 bytes |
| 記録 | |
| コンパイル時間 | 12,921 ms |
| コンパイル使用メモリ | 282,780 KB |
| 実行使用メモリ | 65,024 KB |
| 最終ジャッジ日時 | 2026-04-20 20:17:23 |
| 合計ジャッジ時間 | 21,084 ms |
|
ジャッジサーバーID (参考情報) |
judge1_0 / judge3_1 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 1 |
| other | AC * 6 RE * 7 |
ソースコード
package main
import . "fmt"
import . "sort"
// import . "math"
import . "math/big"
import "math/rand"
const DEBUG = true
const M = 998244353
// func max(a,b int) int { if a>b { return a; } else { return b; } }; func min(a,b int) int { if a<b { return a; } else { return b; } }
func init() {
// N の大きいところで不一致、謎
// check()
}
func check() {
for i := 0; i < 0; i++ {
n := rand.Intn(1e9)+1e9
x := solve(n)
y := solve2(n)
if x != y {
println("n=",n)
println("solve=",x)
println("solve2=",y)
panic("OMG!")
}
}
{
n := int(1e10)+1
x := solve(n)
y := solve2(n)
if x != y {
println("n=",n)
println("solve=",x)
println("solve2=",y)
panic("OMG!")
}
}
}
func main() {
var n int
Scan(&n)
ans := solve(n)
Println(ans)
}
func solve(n int) int {
MI := NewInt(M)
rs := make([]int, 1e6+1)
for i := range rs {
if i > 0 {
rs[i] = int(new(Int).ModInverse(NewInt(int64(i)), MI).Int64())
}
}
// 2乗和
// x * (x + 1) * (2 * x + 1) / 6
p2sum := func(x int) int {
x %= M
return x*(x+1)%M*(2*x%M+1)%M * rs[6] % M
}
// 3乗和
// x * x * (x + 1) * (x + 1) / 4
p3sum := func(x int) int {
x %= M
return x*x%M*(x+1)%M*(x+1)%M * rs[4] % M
}
// 4乗和
// x * (x + 1) * (2 * x + 1) * (3 * x^2 + 3 * x - 1) / 30
p4sum := func(x int) int {
x %= M
return x*(x+1)%M*(2*x%M+1)%M*(((3*x%M*x%M+3*x%M)%M+M-1)%M)%M * rs[30] % M
}
ps := []*P{}
for i := 2; i <= 1e6; i++ {
for k,v := 3,i*i*i; v <= 1e18; k,v =k+1,v*i {
ps=append(ps,&P{v,i,k})
if (int(1e18)+v-1)/v < i {
break
}
}
}
Slice(ps, func(i, j int) bool {
return ps[i].value < ps[j].value
})
ms := make([]int, 60)
for i := range ms {
ms[i] = 1
}
ans := 0
next := 1
mm := 1
for next <= n {
var last *P
if len(ps) > 0 {
last = ps[0]
ps = ps[1:]
} else {
last = &P{ n+1, 1, 1 }
}
lower := int(new(Int).Sqrt(NewInt(int64(next))).Int64())
upper := int(new(Int).Sqrt(NewInt(int64(min(n,last.value-1)))).Int64())
{
a := next
b := min((lower+1)*(lower+1)-1, min(n, last.value-1))
var t int
t = b*(b+1)%M*rs[2]%M
t += M - a*(a-1)%M*rs[2]%M
t %= M
ans += t*(lower%M)%M*mm%M
ans %= M
}
if lower+1 < upper {
// lower < s < upper の各sにおいて
// sqrt(X) == s かつ s*s == X
// の X から
// (s+1)*(s+1) == s*s + 2*s + 1 == Y の Y まで
// ans += ((s+1)*(s+1) + 1) * (s+1) * mm
// ans += ((s+1)*(s+1) + 0) * (s+1) * mm
// ans += (s*s + 2*s) * s * mm
// ...
// ans += (s*s + 2) * s * mm
// ans += (s*s + 1) * s * mm
// ans += (s*s + 0) * s * mm
// X から Y まで 2*s+1 個 (Yを含まず)
// この区間、 sqrt(?) は s である
// X から 2*s+1 個までの和 t は
// s*s + 0 から s*s + 2*s で
// t = Σ{i=0,2*s}(s*s + i)
// = s*s * (2*s+1) + (2*s)*((2*s)+1)/2
// = 2*s^3 + s^2 + s*(2*s+1)
// = 2*s^3 + s^2 + 2*s^2 + s
// = 2*s^3 + 3*s^2 + s
// lower < s < upper までは mm は共通なので
// 各 s の t*s を計算して合計すればよいので
// t*s を展開すると
// t*s = (2*s^3 + 3*s^2 + s)*s
// = 2*s^4 + 3*s^3 + s^2
// 2乗の和の公式、3乗の和の公式、4乗の和の公式
// を使えば、lower < s < upper をまとめて計算できるハズ
// 2乗和
// x * (x + 1) * (2 * x + 1) / 6
// 3乗和
// x * x * (x + 1) * (x + 1) / 4
// 4乗和
// x * (x + 1) * (2 * x + 1) * (3 * x^2 + 3 * x - 1) / 30
s := upper-1
tssum := ((2*p4sum(s)%M+3*p3sum(s)%M)%M+p2sum(s))%M
lwsum := ((2*p4sum(lower)%M+3*p3sum(lower)%M)%M+p2sum(lower))%M
ans += (tssum+M-lwsum)%M*mm%M
ans %= M
if DEBUG {
sum := 0
for s = lower+1; s < upper; s++ {
a := s*s
b := (s+1)*(s+1)-1
t := b*(b+1)%M*rs[2]%M
t += M - a*(a-1)%M*rs[2]%M
t %= M
sum = (sum+t*s%M)%M
}
tmp := (tssum+M-lwsum)%M
if tmp != sum {
println("lower=",lower)
println("upper=",upper)
println("tmp=",tmp)
println("sum=",sum)
panic("OMG!")
}
}
}
if lower < upper {
a := upper*upper
b := min((upper+1)*(upper+1)-1, min(n, last.value-1))
if a <= b {
var t int
t = b*(b+1)%M*rs[2]%M
t += M - a*(a-1)%M*rs[2]%M
t %= M
ans += t*(upper%M)%M*mm%M
ans %= M
}
}
next = last.value
mm = mm*rs[ms[last.power]]%M*last.base%M
ms[last.power] = last.base
for len(ps) > 0 && last.value == ps[0].value {
p := ps[0]
ps = ps[1:]
mm = mm*rs[ms[p.power]]%M*p.base%M
ms[p.power] = p.base
}
}
return ans
}
type P struct { value, base, power int }
// 9 AC したやつ
func solve2(n int) int {
sqrt := func(x int) int {
return int(new(Int).Sqrt(NewInt(int64(x))).Int64())
}
ps := []*P{}
for i := 2; i <= 1e6; i++ {
for k,v := 3,i*i*i; v <= 1e18; k,v =k+1,v*i {
ps=append(ps,&P{v,i,k})
if int(1e18+v-1)/v < i {
break
}
}
}
Slice(ps, func(i, j int) bool {
return ps[i].value < ps[j].value
})
ms := make([]int, 60)
for i := range ms {
ms[i] = 1
}
ans := 0
next := 1
for next <= n && len(ps) > 0 {
last := ps[0]
ps = ps[1:]
mm := 1
for _, m := range ms[3:] {
mm = mm*m%M
}
lower := sqrt(next)
upper := sqrt(last.value)+1
for s := lower; s <= upper; s++ {
a := max(s*s, next)
b := min((s+1)*(s+1)-1, min(n,last.value-1))
if a > n || a > last.value-1 {
break
}
var t int
if b%2 == 0 {
t = (b/2)%M*((b+1)%M)%M
} else {
t = ((b+1)/2)%M*(b%M)%M
}
if a%2 == 0 {
t += M - (a/2)%M*((a-1)%M)%M
} else {
t += M - ((a-1)/2)%M*(a%M)%M
}
t %= M
// Printf("%#v\n", ms)
// Printf("n=%d,next=%d,a=%d,b=%d,s=%d,t=%d,mm=%d,lv=%d,ans=%d,tsm=%d\n",n,next,a,b,s,t,mm,last.value,ans,t*s%M*mm%M)
ans += t*s%M*mm%M
ans %= M
}
next = last.value
ms[last.power] = last.base
for len(ps) > 0 && last.value == ps[0].value {
p := ps[0]
ps = ps[1:]
ms[p.power] = p.base
}
}
return ans
}
ID 21712