"""Print sum_{k=1..m} sum_{x=1..n} (x mod k), reduced modulo 998244353.

Reads two integers ``n`` and ``m`` from one line of standard input.

For a divisor k with quotient q = n // k, twice the inner sum equals
n*(n+1) - 2*(n+1)*q*k + q*(q+1)*k**2.  Every k that shares the same q is
therefore summed at once with the closed forms for sum(k) and sum(k**2),
walking the quotient blocks from k = m downwards.  The running total is kept
as an exact integer and only reduced modulo ``mod`` when printed.
"""

mod = 998244353


def _sum_first(x):
    """Return 1 + 2 + ... + x (0 when x == 0)."""
    return x * (x + 1) // 2


def _sum_squares(x):
    """Return 1**2 + 2**2 + ... + x**2 (0 when x == 0)."""
    return x * (x + 1) * (2 * x + 1) // 6


n, m = map(int, input().split())

total = 0  # exact value of twice the answer
hi = m
while hi > 0:
    q = n // hi          # quotient shared by every k in (lo, hi]
    lo = n // (q + 1)    # largest k with n // k > q (0 if there is none)
    count = hi - lo
    linear = _sum_first(hi) - _sum_first(lo)
    square = _sum_squares(hi) - _sum_squares(lo)
    total += (
        n * (n + 1) * count
        - (2 + 2 * n) * q * linear
        + q * (q + 1) * square
    )
    hi = lo

# total holds twice the answer; multiply by the modular inverse of 2.
print(total * pow(2, mod - 2, mod) % mod)