## https://yukicoder.me/problems/no/1601
# Prints  sum_{i=1}^{M} i * (q + q*(q+1)/2)  mod 998244353,  where q = N // i.
import math

MOD = 998244353


def solve(n, m, mod=MOD):
    """Return sum_{i=1}^{m} i * (q + q*(q+1)//2) modulo *mod*, with q = n // i.

    The indices are walked in maximal blocks [i, j] sharing the same quotient
    q = n // i, so the loop runs O(sqrt(n)) times.  Indices i > n have q == 0
    and contribute nothing, so the walk stops at min(n, m).

    Exact integer halving is used instead of a modular inverse of 2, so the
    result is correct for any positive modulus.

    Args:
        n: the dividend N (non-negative int).
        m: the upper bound M on the index i (non-negative int).
        mod: modulus of the result (defaults to 998244353).

    Returns:
        The sum reduced modulo *mod*.
    """
    limit = min(n, m)
    total = 0
    i = 1
    while i <= limit:
        q = n // i
        # Largest index j with n // j == q, clipped to the summation bound.
        j = min(n // q, limit)
        # q + q(q+1)/2 == q(q+3)/2; q(q+3) is always even, so // is exact.
        weight = q * (q + 3) // 2
        # i + (i+1) + ... + j, again an exact integer.
        span = (i + j) * (j - i + 1) // 2
        total = (total + (span % mod) * (weight % mod)) % mod
        i = j + 1
    return total


def main():
    """Read "N M" from stdin and print solve(N, M)."""
    n, m = map(int, input().split())
    print(solve(n, m))


if __name__ == "__main__":
    main()