import sys


def I():
    """Read one line from stdin and return it as an int."""
    return int(sys.stdin.readline().rstrip())


def MI():
    """Read one line from stdin and return a lazy map of its whitespace-separated ints."""
    return map(int, sys.stdin.readline().rstrip().split())


def LI():
    """Read one line from stdin and return its whitespace-separated ints as a list."""
    return list(map(int, sys.stdin.readline().rstrip().split()))


def LI2():
    """Read one line of digits from stdin and return them as a list of single-digit ints."""
    return list(map(int, sys.stdin.readline().rstrip()))


def S():
    """Read one line from stdin with trailing whitespace stripped."""
    return sys.stdin.readline().rstrip()


def LS():
    """Read one line from stdin and return its whitespace-separated tokens as a list."""
    return list(sys.stdin.readline().rstrip().split())


def LS2():
    """Read one line from stdin and return its characters as a list."""
    return list(sys.stdin.readline().rstrip())


mod = 998244353


def int_sqrt(v):
    """Return floor(sqrt(v)) exactly for a non-negative int *v*.

    Uses integer Newton iteration instead of ``int(v ** .5)``: the float
    version can round below the true root for large *v* (roughly v > 2**52),
    which would make ``solve`` skip terms.
    """
    if v < 2:
        return v
    # Start from a power of two that is guaranteed to be >= sqrt(v); Newton's
    # iteration then decreases monotonically to floor(sqrt(v)).
    x = 1 << ((v.bit_length() + 1) // 2)
    while True:
        y = (x + v // x) // 2
        if y >= x:
            return x
        x = y


def solve(n, m, modulus=mod):
    """Return sum_{i=1}^{m} i * x_i * (x_i + 3) / 2 (mod *modulus*), where x_i = n // i.

    Runs in O(sqrt(n)) using the fact that ``n // i`` takes few distinct values:
    indices i <= floor(sqrt(n)) are summed one by one, and every larger index
    has a quotient x <= floor(sqrt(n)), so those indices are grouped by x.

    Args:
        n: non-negative dividend.
        m: upper bound (inclusive) of the summation index; m <= 0 gives 0.
        modulus: modulus applied to the result (defaults to 998244353).

    Returns:
        The sum reduced modulo *modulus*.
    """
    root = int_sqrt(n)  # exact; a float sqrt could underestimate and drop terms
    small = min(root, m)
    ans = 0

    # Small indices: evaluate each term directly.
    # x * (x + 3) is always even, so the // 2 is exact.
    for i in range(1, small + 1):
        x = n // i
        ans = (ans + i * x * (x + 3) // 2) % modulus

    # Large indices i > small: n // i <= root, so group i by its quotient x.
    # Indices with n // i == x form the range [n // (x + 1) + 1, n // x].
    # Quotient 0 (i > n) contributes nothing, so x starts at 1.
    for x in range(root, 0, -1):
        left = max(n // (x + 1), small) + 1
        right = min(m, n // x)
        if left > right:
            continue
        # Sum of i over [left, right]; the product is always even.
        index_sum = (left + right) * (right - left + 1) // 2
        ans = (ans + x * (x + 3) // 2 * index_sum) % modulus

    return ans


if __name__ == "__main__":
    n, m = MI()
    print(solve(n, m))