MOD = 998244353


def sum_of_remainders(n, m, mod=MOD):
    """Return (sum of m % k for k = 1..n) reduced modulo ``mod``.

    Walks k downward from n in maximal runs that share the same quotient
    q = m // k.  Inside such a run, m % k == m - q*k is an arithmetic
    progression in k, so the run's total is (first + last) * count / 2.
    For k > m the quotient is 0 and every term equals m; that is handled
    by the same formula.

    Args:
        n: upper bound of k (inclusive). n <= 0 yields 0.
        m: the dividend. Assumed non-negative — TODO confirm against the
           problem constraints.
        mod: modulus applied to the final sum (default 998244353).

    Returns:
        The sum modulo ``mod``.
    """
    total = 0
    hi = n
    while hi > 0:
        q = m // hi
        # Every k in (lo, hi] satisfies m // k == q; lo < hi always holds
        # because m < (q + 1) * hi, so the loop strictly descends.
        lo = m // (q + 1)
        first = m % (lo + 1)  # term at k = lo + 1
        last = m % hi         # term at k = hi
        # Sum of an integer arithmetic progression: (first+last)*count is
        # always even, so floor division is exact.
        total += (first + last) * (hi - lo) // 2
        hi = lo
    return total % mod


def main():
    """Read n and m from stdin and print sum_{k=1}^{n} (m mod k) mod MOD."""
    import sys

    n, m = map(int, sys.stdin.read().split()[:2])
    print(sum_of_remainders(n, m))


if __name__ == "__main__":
    main()