# これが ★2.5 ですか mod = 998244353 n,m = map(int,input().split()) i = m ans = 0 while i > 0: v = n // i j = n // (v+1) ans -= (2 + 2 * n) * v * (i*(i+1)//2 - j*(j+1)//2) ans += v * (i*(i+1)*(2*i+1)//6 - j*(j+1)*(2*j+1)//6) ans += n * (i-j) ans += n * n * (i-j) ans += v * v * (i*(i+1)*(2*i+1)//6 - j*(j+1)*(2*j+1)//6) ans %= mod i = j print(ans * pow(2, mod-2, mod) % mod)