結果

問題 No.3505 Sum of Prod of Root
コンテスト
ユーザー ID 21712
提出日時 2026-04-20 18:42:21
言語 Go
(1.26.1)
コンパイル:
env GOCACHE=/tmp go build _filename_
実行:
./Main
結果
WA  
実行時間 -
コード長 3,285 bytes
記録
記録タグの例:
初AC ショートコード 純ショートコード 純主流ショートコード 最速実行時間
コンパイル時間 11,055 ms
コンパイル使用メモリ 287,008 KB
実行使用メモリ 64,536 KB
最終ジャッジ日時 2026-04-20 18:42:45
合計ジャッジ時間 19,254 ms
ジャッジサーバーID
(参考情報)
judge1_0 / judge2_1
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 1
other AC * 7 WA * 6
権限があれば一括ダウンロードができます

ソースコード

diff #
raw source code

package main

import . "fmt"
import . "sort"
import . "math"
import . "math/big"

const M = 998244353

// func max(a,b int) int { if a>b { return a; } else { return b; } }; func min(a,b int) int { if a<b { return a; } else { return b; } }

func main() {
	var n int
	Scan(&n)
	ans := solve(n)
	Println(ans)
}

func solve(n int) int {
	MI := NewInt(M)
	rs := make([]int, 1e6+1)
	for i := range rs {
		if i > 0 {
			rs[i] = int(new(Int).ModInverse(NewInt(int64(i)), MI).Int64())
		}
	}

	// 2乗和
	//  x * (x + 1) * (2 * x + 1) / 6
	p2sum := func(x int) int {
		x %= M
		return x*(x+1)%M*(2*x%M+1)%M * rs[6] % M
	}
	// 3乗和
	//  x * x * (x + 1) * (x + 1) / 4
	p3sum := func(x int) int {
		x %= M
		return x*x%M*(x+1)%M*(x+1)%M * rs[4] % M
	}
	// 4乗和
	//  x * (x + 1) * (2 * x + 1) * (3 * x^2 + 3 * x - 1) / 30	
	p4sum := func(x int) int {
		x %= M
		return x*(x+1)%M*(2*x%M+1)%M*(((3*x%M*x%M+3*x%M)%M+M-1)%M)%M * rs[30] % M
	}
	
	ps := []*P{}
	
	for i := 2; i <= 1e6; i++ {
		for k,v := 3,i*i*i; v <= 1e18; k,v =k+1,v*i {
			ps=append(ps,&P{v,i,k})
			if int(1e18+v-1)/v < i {
				break
			}
		}
	}
	
	Slice(ps, func(i, j int) bool {
		return ps[i].value < ps[j].value
	})
	

	ms := make([]int, 60)
	for i := range ms {
		ms[i] = 1
	}
	
	ans := 0
	next := 1
	mm := 1
	for next <= n && len(ps) > 0 {
		last := ps[0]
		ps = ps[1:]
		lower := int(Sqrt(float64(next)))
		upper := int(Sqrt(float64(min(n,last.value-1))))
		{
			a := next
			b := min((lower+1)*(lower+1)-1, min(n, last.value-1))
			var t int
			t = b*(b+1)%M*rs[2]%M
			t += M - a*(a-1)%M*rs[2]%M
			t %= M
			ans += t*lower%M*mm%M
			ans %= M
		}
		if lower+1 < upper {
			// lower < s < upper の各sにおいて
			// sqrt(X) == s かつ s*s == X
			// の X から 
			// (s+1)*(s+1) == s*s + 2*s + 1 == Y の Y まで
			// X から Y まで 2*s 個 (Yを含まず)
			// この区間、 sqrt(?) は s である
			// X から 2*s+1 個までの和 t は
			// s*s + 0 から s*s + 2*s で
			// t = s*s * (2*s+1) + (2*s)*((2*s)+1)/2
			// lower < s < upper までは mm は共通なので
			// 各 s の t*s を計算して合計すればよいので
			// t*s を展開すると
			// t*s = (s*s * (2*s+1) + (2*s)*((2*s)+1)/2)*s
			//     = 2*s^4 + 3*s^3 + s^2
			// 2乗の和の公式、3乗の和の公式、4乗の和の公式
			// を使えば、lower < s < upper をまとめて計算できるハズ
			// 2乗和
			//  x * (x + 1) * (2 * x + 1) / 6
			// 3乗和
			//  x * x * (x + 1) * (x + 1) / 4
			// 4乗和
			//  x * (x + 1) * (2 * x + 1) * (3 * x^2 + 3 * x - 1) / 30
			s := upper-1
			tssum := ((2*p4sum(s)%M+3*p3sum(s)%M)%M+p2sum(s))%M
			lwsum :=((2*p4sum(lower)%M+3*p3sum(lower)%M)%M+p2sum(lower))%M
			ans += (tssum+M-lwsum)%M*mm%M
			ans %= M
		}
		if lower < upper {
			a := upper*upper
			b := min((upper+1)*(upper+1)-1, min(n, last.value-1))
			if a <= b {
				var t int
				t = b*(b+1)%M*rs[2]%M
				t += M - a*(a-1)%M*rs[2]%M
				t %= M
				ans += t*upper%M*mm%M
				ans %= M
			}
		}
		next = last.value
		mm = mm*rs[ms[last.power]]%M*last.base%M
		ms[last.power] = last.base
		for len(ps) > 0 && last.value == ps[0].value {
			p := ps[0]
			ps = ps[1:]
			mm = mm*rs[ms[p.power]]%M*p.base%M
			ms[p.power] = p.base
		}
	}

	return ans
}

type P struct { value, base, power int }
0