"""Print (sum_{k=1}^{n} (m mod k)) mod 998244353 for integers n, m read from stdin.

Uses the identity ``m mod k = m - k * (m // k)`` and groups the values of
``k`` that share the same quotient ``m // k``. There are O(sqrt(m)) distinct
quotients, so the sum is evaluated in O(sqrt(m)) iterations instead of O(n).
"""

from math import isqrt  # NOTE(review): currently unused; kept as a file-level import.

MOD = 998244353


def decomp(n, m):
    """Yield ``(l, r, q)`` blocks that partition ``k = 1..n``.

    For every ``k`` in the half-open range ``[l, r)``, ``m // k == q``.
    Consecutive blocks are contiguous and the last one ends at ``n + 1``.
    Yields nothing when ``n < 1``. Assumes ``m >= 0``.
    """
    k = 1
    while k <= n:
        q = m // k
        # m // q is the largest k' with m // k' == q. Once q == 0 (k > m),
        # every remaining k has quotient 0, so one block covers the rest.
        nk = min(n + 1, m // q + 1) if q > 0 else n + 1
        yield k, nk, q
        k = nk


def sum_range(l, r):
    """Return ``l + (l + 1) + ... + (r - 1)``, the sum of the half-open range ``[l, r)``.

    ``(l + r - 1)`` and ``(r - l)`` sum to an odd number, so exactly one of
    them is even and the floor division is exact.
    """
    return (l + r - 1) * (r - l) // 2


def solve(n, m, mod=MOD):
    """Return ``sum(m % k for k in 1..n) % mod``, computed in O(sqrt(m)) blocks."""
    total = 0
    for l, r, q in decomp(n, m):
        # Over [l, r): sum(m - q*k) = m*(r - l) - q * sum(l..r-1).
        total += m * (r - l) - sum_range(l, r) * q
    return total % mod


def main():
    """Read ``n m`` from one line of stdin and print the answer."""
    n, m = map(int, input().split())
    print(solve(n, m))


if __name__ == "__main__":
    main()