結果

問題 No.3505 Sum of Prod of Root
コンテスト
ユーザー ID 21712
提出日時 2026-04-21 01:47:42
言語 Go
(1.26.1)
コンパイル:
env GOCACHE=/tmp go build _filename_
実行:
./Main
結果
TLE  
実行時間 -
コード長 7,246 bytes
記録
記録タグの例:
初AC ショートコード 純ショートコード 純主流ショートコード 最速実行時間
コンパイル時間 15,362 ms
コンパイル使用メモリ 271,560 KB
実行使用メモリ 140,856 KB
最終ジャッジ日時 2026-04-21 01:49:02
合計ジャッジ時間 22,026 ms
ジャッジサーバーID
(参考情報)
judge2_0 / judge1_0
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample -- * 1
other AC * 3 TLE * 1 -- * 9
権限があれば一括ダウンロードができます

ソースコード

diff #
raw source code

package main

import . "fmt"
import . "sort"
// import . "math"
import . "math/big"
import "math/rand"

const DEBUG = true

const M = 998244353

// func max(a,b int) int { if a>b { return a; } else { return b; } }; func min(a,b int) int { if a<b { return a; } else { return b; } }

func init() {
	// check()
}

func check() {
	for i := 0; i < 0; i++ {
		n := rand.Intn(1e9)+1e9
		x := solve(n)
		y := solve2(n)
		if x != y {
			println("n=",n)
			println("solve=",x)
			println("solve2=",y)
			panic("OMG!")
		}
	}
	{	
		n := int(1e10)+1
		x := solve(n)
		y := solve2(n)
		if x != y {
			println("n=",n)
			println("solve=",x)
			println("solve2=",y)
			panic("OMG!")
		}
	}
}

func main() {
	var n int
	Scan(&n)
	ans := solve(n)
	Println(ans)
}

func solve(n int) int {
	MI := NewInt(M)
	rs := make([]int, 1e6+1)
	for i := range rs {
		if i > 0 {
			rs[i] = int(new(Int).ModInverse(NewInt(int64(i)), MI).Int64())
		}
	}

	// 2乗和
	//  x * (x + 1) * (2 * x + 1) / 6
	p2sum := func(x int) int {
		x %= M
		return x * (x+1) %M * (2*x%M + 1)%M * rs[6] % M
	}
	// 3乗和
	//  x * x * (x + 1) * (x + 1) / 4
	p3sum := func(x int) int {
		x %= M
		return x * x%M * (x+1)%M * (x+1)%M * rs[4] % M
	}
	// 4乗和
	//  x * (x + 1) * (2 * x + 1) * (3 * x^2 + 3 * x - 1) / 30	
	p4sum := func(x int) int {
		x %= M
		return x * (x+1)%M * (2*x%M + 1)%M * ( ( ( 3*x%M*x%M + 3*x%M )%M + M-1 ) % M ) % M * rs[30] % M
	}
	
	ps := []*P{}
	
	for i := 2; i <= 1e6; i++ {
		for k,v := 3,i*i*i; v <= 1e18; k,v =k+1,v*i {
			ps=append(ps,&P{v,i,k})
			if (int(1e18)+v-1)/v < i {
				break
			}
		}
	}
	
	Slice(ps, func(i, j int) bool {
		return ps[i].value < ps[j].value
	})
	

	ms := make([]int, 60)
	for i := range ms {
		ms[i] = 1
	}
	
	ans := 0
	next := 1
	mm := 1
	for next <= n  {
		var last *P
		if len(ps) > 0 {
			last = ps[0]
			ps = ps[1:]
		} else {
			last = &P{ n+1, 1, 1 }
		}
		lower := int(new(Int).Sqrt(NewInt(int64(next))).Int64())
		upper := int(new(Int).Sqrt(NewInt(int64(min(n,last.value-1)))).Int64())
		{
			a := next
			b := min((lower+1)*(lower+1)-1, min(n, last.value-1))
			var t int
			t = b%M*(b%M+1)%M*rs[2]%M
			t += M - a%M*(a%M-1)%M*rs[2]%M
			t %= M
			ans += t*(lower%M)%M*mm%M
			ans %= M
		}
		if lower+1 < upper {
			// lower < s < upper の各sにおいて
			// sqrt(X) == s かつ s*s == X
			// の X から 
			// (s+1)*(s+1) == s*s + 2*s + 1 == Y の Y まで
			//   ans += ((s+1)*(s+1) + 1) * (s+1) * mm
			//   ans += ((s+1)*(s+1) + 0) * (s+1) * mm
			//   ans += (s*s + 2*s) * s * mm
			//    ...
			//   ans += (s*s +   2) * s * mm
			//   ans += (s*s +   1) * s * mm
			//   ans += (s*s +   0) * s * mm
			// X から Y まで 2*s+1 個 (Yを含まず)
			// この区間、 sqrt(?) は s である
			// X から 2*s+1 個までの和 t は
			// s*s + 0 から s*s + 2*s で
			// t = Σ{i=0,2*s}(s*s + i)
			//   = s*s * (2*s+1) + (2*s)*((2*s)+1)/2
			//   = 2*s^3 + s^2 + s*(2*s+1)
			//   = 2*s^3 + s^2 + 2*s^2 + s
			//   = 2*s^3 + 3*s^2 + s
			// lower < s < upper までは mm は共通なので
			// 各 s の t*s を計算して合計すればよいので
			// t*s を展開すると
			// t*s = (2*s^3 + 3*s^2 + s)*s
			//     = 2*s^4 + 3*s^3 + s^2
			// 2乗の和の公式、3乗の和の公式、4乗の和の公式
			// を使えば、lower < s < upper をまとめて計算できるハズ
			// 2乗和
			//  x * (x + 1) * (2 * x + 1) / 6
			// 3乗和
			//  x * x * (x + 1) * (x + 1) / 4
			// 4乗和
			//  x * (x + 1) * (2 * x + 1) * (3 * x^2 + 3 * x - 1) / 30
			s := upper-1
			tssum := ( ( 2*p4sum(s)%M + 3*p3sum(s)%M ) %M + p2sum(s) ) % M
			lwsum := ( ( 2*p4sum(lower)%M + 3*p3sum(lower)%M ) %M + p2sum(lower) ) % M
			ans += (tssum+M-lwsum)%M * mm%M
			ans %= M
			if DEBUG {
				sum := 0
				for s = lower+1; s < upper; s++ {
					a := s*s
					b := (s+1)*(s+1)-1
					t := b%M*(b%M+1)%M*rs[2]%M
					t += M - a%M*(a%M-1)%M*rs[2]%M
					t %= M
					sum = (sum+t*s%M)%M
				}
				tmp := (tssum+M-lwsum)%M
				if tmp != sum {
					println("lower=",lower)
					println("upper=",upper)
					println("tmp=",tmp)
					println("sum=",sum)
					println("tssum=",tssum)
					println("lwsum=",lwsum)
					
					sum = 0
					for s = lower+1; s < upper; s++ {
						a := s*s
						b := (s+1)*(s+1)-1
						t := b%M*(b%M+1)%M*rs[2]%M
						t += M - a%M*(a%M-1)%M*rs[2]%M
						t %= M
						sum = (sum+t*s%M)%M
						println("s=",s,",a=",a,",b=",b,"t=",t,"sum=",sum)
						tssum = ( ( 2*p4sum(s)%M + 3*p3sum(s)%M ) %M + p2sum(s) ) % M
						lwsum = ( ( 2*p4sum(lower)%M + 3*p3sum(lower)%M ) %M + p2sum(lower) ) % M	
						println("tssum=",tssum,",lwsum=",lwsum,",t-l=",(tssum+M-lwsum)%M)
					}

					panic("OMG!")
				}
				// a := s*s
				// b := (s+1)*(s+1)-1
				// t := b*(b+1)/2
				//      - a*(a-1)/2
				// ts := t * s
				// これを整理する
				// 2 * t = b*(b+1) - a*(a-1)
				//       = ((s+1)*(s+1)-1)*((s+1)*(s+1))
				//           - (s*s)*(s*s-1)
				//       = (s^2 + 2*s)*(s^2 + 2*s + 1)
				//           - (s^4 - s^2)
				//       = s^2 * (s^2 + 2*s + 1)
				//           + 2*s * (s^2 + 2*s + 1)
				//           - (s^4 - s^2)
				//       = s^4 + 2*s^3 + s^2
				//           + 2*s^3 + 4*s^2 + 2*s
				//           - s^4 + s^2
				//       = 4*s^3 + 6*s^2 + 2*s
				//     t = 2*s^3 + 3*s^2 + s
				// s * t = 2*s^4 + 3*s^3 + s^2
				// 意味なし
			}
		}
		if lower < upper {
			a := upper*upper
			b := min((upper+1)*(upper+1)-1, min(n, last.value-1))
			if a <= b {
				var t int
				t = b%M*(b%M+1)%M*rs[2]%M
				t += M - a%M*(a%M-1)%M*rs[2]%M
				t %= M
				ans += t*(upper%M)%M*mm%M
				ans %= M
			}
		}
		next = last.value
		mm = mm*rs[ms[last.power]]%M*last.base%M
		ms[last.power] = last.base
		for len(ps) > 0 && last.value == ps[0].value {
			p := ps[0]
			ps = ps[1:]
			mm = mm*rs[ms[p.power]]%M*p.base%M
			ms[p.power] = p.base
		}
	}

	return ans
}

type P struct { value, base, power int }

// 9 AC したやつ
func solve2(n int) int {
	
	sqrt := func(x int) int {
		return int(new(Int).Sqrt(NewInt(int64(x))).Int64())
	}
	
	ps := []*P{}
	
	for i := 2; i <= 1e6; i++ {
		for k,v := 3,i*i*i; v <= 1e18; k,v =k+1,v*i {
			ps=append(ps,&P{v,i,k})
			if int(1e18+v-1)/v < i {
				break
			}
		}
	}
	
	Slice(ps, func(i, j int) bool {
		return ps[i].value < ps[j].value
	})
	

	ms := make([]int, 60)
	for i := range ms {
		ms[i] = 1
	}
	
	ans := 0
	next := 1
	for next <= n && len(ps) > 0 {
		last := ps[0]
		ps = ps[1:]
		mm := 1
		for _, m := range ms[3:] {
			mm = mm*m%M
		}
		lower := sqrt(next)
		upper := sqrt(last.value)+1
		for s := lower; s <= upper; s++ {
			a := max(s*s, next)
			b := min((s+1)*(s+1)-1, min(n,last.value-1))
			if a > n || a > last.value-1 {
				break
			}
			var t int
			if b%2 == 0 {
				t = (b/2)%M*((b+1)%M)%M
			} else {
				t = ((b+1)/2)%M*(b%M)%M
			}
			if a%2 == 0 {
				t += M - (a/2)%M*((a-1)%M)%M
			} else {
				t += M - ((a-1)/2)%M*(a%M)%M
			}
			t %= M
			// Printf("%#v\n", ms)
			// Printf("n=%d,next=%d,a=%d,b=%d,s=%d,t=%d,mm=%d,lv=%d,ans=%d,tsm=%d\n",n,next,a,b,s,t,mm,last.value,ans,t*s%M*mm%M)
			ans += t*s%M*mm%M
			ans %= M
		}
		next = last.value
		ms[last.power] = last.base
		for len(ps) > 0 && last.value == ps[0].value {
			p := ps[0]
			ps = ps[1:]
			ms[p.power] = p.base
		}
	}

	return ans
}

0