package main import . "fmt" import . "sort" // import . "math" import . "math/big" import "math/rand" const DEBUG = false const M = 998244353 // func max(a,b int) int { if a>b { return a; } else { return b; } }; func min(a,b int) int { if a 0 { rs[i] = int(new(Int).ModInverse(NewInt(int64(i)), MI).Int64()) } } // 2乗和 // x * (x + 1) * (2 * x + 1) / 6 p2sum := func(x int) int { x %= M return x * (x+1) %M * (2*x%M + 1)%M * rs[6] % M } // 3乗和 // x * x * (x + 1) * (x + 1) / 4 p3sum := func(x int) int { x %= M return x * x%M * (x+1)%M * (x+1)%M * rs[4] % M } // 4乗和 // x * (x + 1) * (2 * x + 1) * (3 * x^2 + 3 * x - 1) / 30 p4sum := func(x int) int { x %= M return x * (x+1)%M * (2*x%M + 1)%M * ( ( ( 3*x%M*x%M + 3*x%M )%M + M-1 ) % M ) % M * rs[30] % M } ps := []*P{} for i := 2; i <= 1e6; i++ { for k,v := 3,i*i*i; v <= 1e18; k,v =k+1,v*i { ps=append(ps,&P{v,i,k}) if (int(1e18)+v-1)/v < i { break } } } Slice(ps, func(i, j int) bool { return ps[i].value < ps[j].value }) ms := make([]int, 60) for i := range ms { ms[i] = 1 } ans := 0 next := 1 mm := 1 for next <= n { var last *P if len(ps) > 0 { last = ps[0] ps = ps[1:] } else { last = &P{ n+1, 1, 1 } } lower := int(new(Int).Sqrt(NewInt(int64(next))).Int64()) upper := int(new(Int).Sqrt(NewInt(int64(min(n,last.value-1)))).Int64()) { a := next b := min((lower+1)*(lower+1)-1, min(n, last.value-1)) var t int t = b%M*(b%M+1)%M*rs[2]%M t += M - a%M*(a%M-1)%M*rs[2]%M t %= M ans += t*(lower%M)%M*mm%M ans %= M } if lower+1 < upper { // lower < s < upper の各sにおいて // sqrt(X) == s かつ s*s == X // の X から // (s+1)*(s+1) == s*s + 2*s + 1 == Y の Y まで // ans += ((s+1)*(s+1) + 1) * (s+1) * mm // ans += ((s+1)*(s+1) + 0) * (s+1) * mm // ans += (s*s + 2*s) * s * mm // ... // ans += (s*s + 2) * s * mm // ans += (s*s + 1) * s * mm // ans += (s*s + 0) * s * mm // X から Y まで 2*s+1 個 (Yを含まず) // この区間、 sqrt(?) は s である // X から 2*s+1 個までの和 t は // s*s + 0 から s*s + 2*s で // t = Σ{i=0,2*s}(s*s + i) // = s*s * (2*s+1) + (2*s)*((2*s)+1)/2 // = 2*s^3 + s^2 + s*(2*s+1) // = 2*s^3 + s^2 + 2*s^2 + s // = 2*s^3 + 3*s^2 + s // lower < s < upper までは mm は共通なので // 各 s の t*s を計算して合計すればよいので // t*s を展開すると // t*s = (2*s^3 + 3*s^2 + s)*s // = 2*s^4 + 3*s^3 + s^2 // 2乗の和の公式、3乗の和の公式、4乗の和の公式 // を使えば、lower < s < upper をまとめて計算できるハズ // 2乗和 // x * (x + 1) * (2 * x + 1) / 6 // 3乗和 // x * x * (x + 1) * (x + 1) / 4 // 4乗和 // x * (x + 1) * (2 * x + 1) * (3 * x^2 + 3 * x - 1) / 30 s := upper-1 tssum := ( ( 2*p4sum(s)%M + 3*p3sum(s)%M ) %M + p2sum(s) ) % M lwsum := ( ( 2*p4sum(lower)%M + 3*p3sum(lower)%M ) %M + p2sum(lower) ) % M ans += (tssum+M-lwsum)%M * mm%M ans %= M if DEBUG { sum := 0 for s = lower+1; s < upper; s++ { a := s*s b := (s+1)*(s+1)-1 t := b%M*(b%M+1)%M*rs[2]%M t += M - a%M*(a%M-1)%M*rs[2]%M t %= M sum = (sum+t*s%M)%M } tmp := (tssum+M-lwsum)%M if tmp != sum { println("lower=",lower) println("upper=",upper) println("tmp=",tmp) println("sum=",sum) println("tssum=",tssum) println("lwsum=",lwsum) sum = 0 for s = lower+1; s < upper; s++ { a := s*s b := (s+1)*(s+1)-1 t := b%M*(b%M+1)%M*rs[2]%M t += M - a%M*(a%M-1)%M*rs[2]%M t %= M sum = (sum+t*s%M)%M println("s=",s,",a=",a,",b=",b,"t=",t,"sum=",sum) tssum = ( ( 2*p4sum(s)%M + 3*p3sum(s)%M ) %M + p2sum(s) ) % M lwsum = ( ( 2*p4sum(lower)%M + 3*p3sum(lower)%M ) %M + p2sum(lower) ) % M println("tssum=",tssum,",lwsum=",lwsum,",t-l=",(tssum+M-lwsum)%M) } panic("OMG!") } // a := s*s // b := (s+1)*(s+1)-1 // t := b*(b+1)/2 // - a*(a-1)/2 // ts := t * s // これを整理する // 2 * t = b*(b+1) - a*(a-1) // = ((s+1)*(s+1)-1)*((s+1)*(s+1)) // - (s*s)*(s*s-1) // = (s^2 + 2*s)*(s^2 + 2*s + 1) // - (s^4 - s^2) // = s^2 * (s^2 + 2*s + 1) // + 2*s * (s^2 + 2*s + 1) // - (s^4 - s^2) // = s^4 + 2*s^3 + s^2 // + 2*s^3 + 4*s^2 + 2*s // - s^4 + s^2 // = 4*s^3 + 6*s^2 + 2*s // t = 2*s^3 + 3*s^2 + s // s * t = 2*s^4 + 3*s^3 + s^2 // 意味なし } } if lower < upper { a := upper*upper b := min((upper+1)*(upper+1)-1, min(n, last.value-1)) if a <= b { var t int t = b%M*(b%M+1)%M*rs[2]%M t += M - a%M*(a%M-1)%M*rs[2]%M t %= M ans += t*(upper%M)%M*mm%M ans %= M } } next = last.value mm = mm*rs[ms[last.power]]%M*last.base%M ms[last.power] = last.base for len(ps) > 0 && last.value == ps[0].value { p := ps[0] ps = ps[1:] mm = mm*rs[ms[p.power]]%M*p.base%M ms[p.power] = p.base } } return ans } type P struct { value, base, power int } // 9 AC したやつ func solve2(n int) int { sqrt := func(x int) int { return int(new(Int).Sqrt(NewInt(int64(x))).Int64()) } ps := []*P{} for i := 2; i <= 1e6; i++ { for k,v := 3,i*i*i; v <= 1e18; k,v =k+1,v*i { ps=append(ps,&P{v,i,k}) if int(1e18+v-1)/v < i { break } } } Slice(ps, func(i, j int) bool { return ps[i].value < ps[j].value }) ms := make([]int, 60) for i := range ms { ms[i] = 1 } ans := 0 next := 1 for next <= n && len(ps) > 0 { last := ps[0] ps = ps[1:] mm := 1 for _, m := range ms[3:] { mm = mm*m%M } lower := sqrt(next) upper := sqrt(last.value)+1 for s := lower; s <= upper; s++ { a := max(s*s, next) b := min((s+1)*(s+1)-1, min(n,last.value-1)) if a > n || a > last.value-1 { break } var t int if b%2 == 0 { t = (b/2)%M*((b+1)%M)%M } else { t = ((b+1)/2)%M*(b%M)%M } if a%2 == 0 { t += M - (a/2)%M*((a-1)%M)%M } else { t += M - ((a-1)/2)%M*(a%M)%M } t %= M // Printf("%#v\n", ms) // Printf("n=%d,next=%d,a=%d,b=%d,s=%d,t=%d,mm=%d,lv=%d,ans=%d,tsm=%d\n",n,next,a,b,s,t,mm,last.value,ans,t*s%M*mm%M) ans += t*s%M*mm%M ans %= M } next = last.value ms[last.power] = last.base for len(ps) > 0 && last.value == ps[0].value { p := ps[0] ps = ps[1:] ms[p.power] = p.base } } return ans }