## https://yukicoder.me/problems/no/2127
"""Compute sum_{j=1..M} sum_{i=1..N} (i mod j) modulo 998244353.

For a fixed j, write N = q*j + r with 0 <= r < j.  The residues of
i = 0..N consist of q complete cycles 0..j-1 followed by 0..r, hence

    sum_{i=1..N} (i mod j) = q * j*(j-1)/2 + r*(r+1)/2.

Small j (j <= isqrt(N)) are summed one by one.  Larger j are grouped by
their quotient q = N // j, which takes at most isqrt(N) + 1 distinct values;
within one group r = N - q*j, so both terms are polynomials in j and can be
summed over a whole interval of j in closed form.
"""

import math

MOD = 998244353

# Modular inverses of the denominators that appear in the closed-form sums.
# MOD is prime, so Fermat's little theorem gives x^-1 = x^(MOD-2) (mod MOD).
_INV2 = pow(2, MOD - 2, MOD)
_INV4 = pow(4, MOD - 2, MOD)
_INV6 = pow(6, MOD - 2, MOD)
_INV12 = pow(12, MOD - 2, MOD)


def _sum_cycles(n):
    """Return sum_{j=1..n} j*(j-1)/2 = (n-1)*n*(n+1)/6, modulo MOD."""
    return (n - 1) * n % MOD * (n + 1) % MOD * _INV6 % MOD


def _sum_tails(n, N, q):
    """Return sum_{j=1..n} r_j*(r_j+1)/2 modulo MOD, where r_j = N - q*j.

    Expanding (N - q*j) * (N - q*j + 1) / 2 and summing over j gives
        q^2 * n(n+1)(2n+1) / 12  -  (2N+1) * q * n(n+1) / 4  +  N(N+1) * n / 2.
    The individual terms need not be integers, so each division is performed
    with a modular inverse; only their combination is the integer result.
    """
    quadratic = q * q % MOD * n % MOD * (n + 1) % MOD
    quadratic = quadratic * (2 * n + 1) % MOD * _INV12 % MOD

    linear = (2 * N + 1) % MOD * q % MOD * n % MOD
    linear = linear * (n + 1) % MOD * _INV4 % MOD

    constant = N * (N + 1) % MOD * n % MOD * _INV2 % MOD

    return (quadratic - linear + constant) % MOD


def _solve(N, M):
    """Return sum_{j=1..M} sum_{i=1..N} (i mod j) modulo MOD.

    N must be non-negative (math.isqrt rejects negative values); N = 0 yields 0.
    """
    # Exact integer square root: unlike int(math.sqrt(N)) it cannot be off by
    # one for large N, which is what justifies the bound on q used below.
    root = math.isqrt(N)
    total = 0

    # Case 1: j <= isqrt(N) -- evaluate the per-j formula directly.
    # j*(j-1) and r*(r+1) are always even, so integer division is exact.
    for j in range(1, min(M, root) + 1):
        q, r = divmod(N, j)
        total = (total + q * (j * (j - 1) // 2) + r * (r + 1) // 2) % MOD

    # Case 2: j > isqrt(N) -- group the remaining j by q = N // j.
    # For such j we have j >= root + 1 and N < (root + 1)^2, so q <= root.
    if M > root:
        for q in range(root + 1):
            # j has quotient q exactly when N // (q + 1) < j <= N // q
            # (no upper limit for q == 0); intersect that with (root, M].
            lower = max(N // (q + 1), root)
            upper = M if q == 0 else min(M, max(N // q, root))
            if upper <= lower:
                continue
            cycles = q * (_sum_cycles(upper) - _sum_cycles(lower))
            tails = _sum_tails(upper, N, q) - _sum_tails(lower, N, q)
            total = (total + cycles + tails) % MOD

    return total


def main():
    """Read "N M" from standard input and print the answer modulo MOD."""
    N, M = map(int, input().split())
    print(_solve(N, M))


if __name__ == "__main__":
    main()