n,m = map(int,input().split()) i = n a = 0 while i > 0: v = m // i j = m // (v+1) a += (m%(j+1) + m%i) * (i-j) i = j print(a//2%998244353)