"""Sum ``d * (T(n // d) + n // d)`` over ``d = 1..min(n, m)`` modulo 998244353.

Here ``T(q) = q * (q + 1) // 2`` is the q-th triangular number. When run as a
script, reads two integers ``n`` and ``m`` from standard input and prints the
result.
"""

import sys

MOD = 998244353


def solve(n: int, m: int) -> int:
    """Return sum_{d=1}^{min(n, m)} d * (q*(q+1)/2 + q) mod MOD, where q = n // d.

    The sum is evaluated in O(sqrt(n)) loop iterations by grouping consecutive
    values of ``d`` that share the same quotient ``q = n // d``.  Returns 0 when
    ``min(n, m) < 1`` (empty summation range).

    >>> solve(2, 2)
    9
    """
    limit = min(n, m)
    total = 0
    d = 1
    while d <= limit:
        q = n // d  # always >= 1 here, because d <= limit <= n
        # Largest d' with n // d' == q, clipped to the summation range.
        block_end = min(n // q, limit)
        # Sum of the consecutive integers d..block_end (arithmetic series).
        sum_d = (d + block_end) * (block_end - d + 1) // 2
        # q*(q+1)/2 + q == q*(q+3)/2; q and q+3 differ in parity, so the
        # division is exact.
        total = (total + q * (q + 3) // 2 * sum_d) % MOD
        d = block_end + 1
    return total


def main() -> None:
    """Read ``n`` and ``m`` (any whitespace-separated layout) and print the answer."""
    n, m = map(int, sys.stdin.read().split()[:2])
    print(solve(n, m))


if __name__ == "__main__":
    main()