MOD = 998244353


def _mod_prefix_sum(m, x):
    """Return sum(y % x for y in range(m + 1)) in O(1).

    The residues 0..x-1 repeat ``m // x`` full times (each cycle sums to
    x*(x-1)/2), followed by a partial run 0..(m % x).
    """
    q, r = divmod(m, x)
    return x * (x - 1) // 2 * q + r * (r + 1) // 2


def fast(n):
    """Return the sum of (y mod x) over 1 <= x, y <= n with y mod x <= y // x.

    Writing y = q*x + r (0 <= r < x), each pair contributes r when r <= q.
    The result is exact (not reduced modulo anything); it is 0 for n <= 1.
    Runs in O(sqrt(n)) time by splitting divisors x at s, the smallest
    x >= 1 with x*x >= n.
    """
    total = 0

    # s = smallest x >= 1 with x*x >= n (stays 1 when n <= 1).
    s = 1
    while s * s < n:
        s += 1

    # Small divisors x < s, i.e. x*x < n.
    for x in range(1, s):
        # y < x*x: quotient q in 1..x-1 and remainders r in 1..q count,
        # giving sum_{q=1}^{x-1} q(q+1)/2 = (x-1)x(x+1)/6.
        total += x * (x - 1) * (x + 1) // 6
        # y in (x*x, n]: q >= x > r, so every remainder counts.
        total += _mod_prefix_sum(n, x) - _mod_prefix_sum(x * x, x)

    # Large divisors x >= s: here q = y // x < x is small, so iterate over q.
    q = 1
    while q <= n:
        x_max = (n - 1) // q  # largest x with q*x + 1 <= n
        if x_max < s:
            break  # x_max only shrinks as q grows
        # Every x in [s, x_max - 1] sees the full remainder range 1..q.
        if x_max > s:
            total += (x_max - s) * q * (q + 1) // 2
        # x == x_max: remainders 1..m, truncated because y must stay <= n.
        m = min(q, n - q * x_max)
        total += m * (m + 1) // 2
        q += 1
    return total


def main():
    """Read n from stdin and print fast(n) reduced modulo MOD."""
    n = int(input())
    print(fast(n) % MOD)


if __name__ == "__main__":
    main()