## https://yukicoder.me/problems/no/1385 import math MOD = 998244353 inv2 = pow(2, MOD - 2, MOD) def sum_k(n): ans = (n * (n + 1)) % MOD ans *= inv2 ans %= MOD return ans def solve_normal(N, M): answer = 0 for k in range(1, N + 1): n = (M % k) answer += n answer %= MOD return answer def solve_true(N, M): sqrt_n = int(math.sqrt(M)) # 前半は正規に解く answer1 = 0 for k in range(1, min(N, sqrt_n) + 1): n = (M % k) answer1 += n answer1 %= MOD # 後半は商に注目 start = N answer2 = 0 for q in range(sqrt_n + 1): end = M // (q + 1) end = max(sqrt_n, end) if start > end: # [start, end)が商qの範囲 if q == 0: n1 = (N - M) % MOD ans = (n1 * M) % MOD answer2 += ans answer2 %= MOD else: init_s = M // q init_r = M % q a1 = (-q * sum_k(start)) % MOD b1 = (init_s * q) % MOD b1 += init_r b1 %= MOD b1 *= start b1 %= MOD a2 = (-q * sum_k(end)) % MOD b2 = (init_s * q) % MOD b2 += init_r b2 %= MOD b2 *= end b2 %= MOD ans = (a1 -a2) % MOD ans += b1 ans %= MOD ans -= b2 ans %= MOD answer2 += ans answer2 %= MOD start = end if end == sqrt_n: break return (answer1 + answer2) % MOD def main(): N, M = map(int, input().split()) # ans1 = solve_normal(N, M) ans2 = solve_true(N, M) # print(ans1) print(ans2) if __name__ == "__main__": main()