# M mod k = M - k * (M // k) N, M = map(int, input().split()) # i <= M // k < i + 1 # M / (i + 1) < k <= M / i # ceil((M + 1) / (i + 1)) <= k <= floor(M / i) ans = 0 k = 1 while M // k > k: if k <= N: ans += M - k * (M // k) k += 1 for i in range(0, M // k + 1): l = min(N + 1, (M + 1 + i) // (i + 1)) r = min(N + 1, M // i + 1) if i >= 1 else N + 1 ans += M * (r - l) - (r - l) * (l + r - 1) // 2 * i print(ans % 998244353)