"""Compute sum_{i=1}^{m} i * q * (q + 3) / 2  (mod 998244353), with q = n // i.

Input: one line with two integers ``n m``. Output: the sum above.

The floor quotient ``n // i`` takes only O(sqrt(n)) distinct values, so the
sum is evaluated block by block: every i sharing the same quotient x lies in
the interval (n // (x + 1), n // x], and that whole interval is summed at once.
"""
import math

MOD = 998244353
INV2 = 499122177  # modular inverse of 2 modulo MOD: 2 * INV2 % MOD == 1


def _block(x, n, m):
    """Return the total contribution, mod MOD, of every i <= m with n // i == x.

    Those i form the interval (n // (x + 1), n // x], with both ends capped
    at m (an empty interval contributes 0). Each such i contributes
    i * x * (x + 3) / 2, so the block total is
    (sum of i over the interval) * x * (x + 3) / 2.
    """
    lo = min(m, n // (x + 1))
    hi = min(m, n // x)
    # Arithmetic series lo+1 + ... + hi, halved via the modular inverse of 2.
    interval_sum = (lo + 1 + hi) * (hi - lo) % MOD * INV2 % MOD
    return interval_sum * x % MOD * (x + 3) % MOD * INV2 % MOD


def solve(n, m):
    """Return sum_{i=1}^{m} i * q * (q + 3) / 2 modulo MOD, where q = n // i.

    Terms with i > n have q == 0 and contribute nothing.
    """
    # math.isqrt is exact; int(n ** 0.5) can be off by one for large n,
    # which would double-count or drop a quotient block.
    root = math.isqrt(n)
    total = 0
    for x in range(1, root + 1):
        # Small quotients 1..root.
        total = (total + _block(x, n, m)) % MOD
        # Large quotient n // x. It equals x only when x == root and
        # n // root == root; skip it then so the block is not counted twice.
        q = n // x
        if q != x:
            total = (total + _block(q, n, m)) % MOD
    return total


def main():
    n, m = map(int, input().split())
    print(solve(n, m))


if __name__ == "__main__":
    main()