n,m=map(int,input().split());i=n;a=0 while i>0:v=m//i;j=m//(v+1);a+=(m%(j+1)+m%i)*(i-j);i=j print(a//2%998244353)