"""Print the sum of (i mod j) over 1 <= i <= n, 1 <= j <= m, modulo 998244353.

Input: one line on standard input holding two whitespace-separated integers
``n m``.  Output: the sum described above, reduced modulo 998244353.
"""
import math  # kept from the original script; the block walk below needs no sqrt

MOD = 998244353


def solve(n: int, m: int) -> int:
    """Return sum(i % j for i in 1..n for j in 1..m) modulo ``MOD``.

    For a fixed ``j`` with ``q = n // j``::

        sum_{i=1..n} (i mod j) = n(n+1)/2 - (n+1)*q*j + q(q+1)/2 * j**2

    The first term is the same for every ``j`` and is added up front as
    ``n(n+1)/2 * m``.  The remaining two terms depend on ``j`` only through
    ``q``, and ``q`` is constant on contiguous ranges of ``j``, so each range
    is handled in O(1) using closed forms for ``sum j`` and ``sum j**2``.
    There are O(sqrt(n)) such ranges.

    Args:
        n: Upper bound of ``i``; expected to be non-negative.
        m: Upper bound of ``j``; expected to be non-negative.

    Returns:
        The double sum reduced into the range ``[0, MOD)``.
    """
    total = n * (n + 1) // 2 * m % MOD

    # For j > n the quotient n // j is 0, so those j add nothing beyond the
    # constant term already counted above.  Capping at min(n, m) also makes
    # n == 0 and m == 0 fall straight through without any division.
    top = min(n, m)

    lo = 1
    while lo <= top:
        q = n // lo                # q >= 1 because lo <= top <= n
        hi = min(top, n // q)      # last j that still has n // j == q
        count = hi - lo + 1

        # Closed forms over j in [lo, hi]; both divisions are exact.
        sum_j = (lo + hi) * count // 2
        sum_j2 = (
            hi * (hi + 1) * (2 * hi + 1) - (lo - 1) * lo * (2 * lo - 1)
        ) // 6

        total += q * (q + 1) // 2 * sum_j2 - (n + 1) * q * sum_j
        # Python's % with a positive modulus is never negative, so no
        # separate "if total < 0" fix-up is needed.
        total %= MOD

        lo = hi + 1

    return total


def main() -> None:
    """Read ``n m`` from standard input and print ``solve(n, m)``."""
    n, m = map(int, input().split())
    print(solve(n, m))


if __name__ == "__main__":
    main()