## https://yukicoder.me/problems/no/1385

import math

MOD = 998244353

# Modular inverse of 2 via Fermat's little theorem (998244353 is prime).
inv2 = pow(2, MOD - 2, MOD)


def sum_k(n):
    """Return the triangular number 1 + 2 + ... + n = n(n+1)/2, modulo MOD."""
    triangular = n * (n + 1) % MOD
    return triangular * inv2 % MOD

def solve_normal(N, M):
    """Brute-force reference: return sum of M % k for k = 1..N, modulo MOD.

    Runs in O(N) time.
    """
    return sum(M % k for k in range(1, N + 1)) % MOD

def solve_true(N, M):
    """Return (M mod 1 + M mod 2 + ... + M mod N) modulo MOD in O(sqrt(M)).

    Divisors k <= isqrt(M) are summed one by one. Larger divisors are grouped
    into ranges on which the quotient q = M // k is constant. On such a range
    M % k == M - q * k, so the range sum has a closed form.

    Args:
        N: upper end of the divisor range (k runs over 1..N); presumably a
            non-negative integer from the problem input.
        M: the dividend; must be a non-negative integer.

    Returns:
        The sum reduced modulo MOD, in [0, MOD).
    """
    # Exact integer square root. int(math.sqrt(M)) depends on float rounding
    # and overflows for very large M. Loop termination below needs
    # (root + 1) ** 2 > M.
    root = math.isqrt(M)

    # Small divisors 1 <= k <= min(N, root): sum directly.
    answer1 = 0
    for k in range(1, min(N, root) + 1):
        answer1 = (answer1 + M % k) % MOD

    # Large divisors root < k <= N: walk quotients q = 0, 1, 2, ... while the
    # divisor interval (end, start] moves from N down towards root.
    start = N
    answer2 = 0
    for q in range(root + 1):
        # Every k in (M // (q + 1), M // q] has M // k == q. Clamp at root
        # because k <= root was already counted above.
        end = max(root, M // (q + 1))
        if start > end:
            # Sum over k in (end, start] of (M - q * k)
            #   = M * (start - end) - q * (S(start) - S(end)).
            # For q == 0 this is M * (N - M), the k > M tail.
            count = start - end
            part = M % MOD * (count % MOD) - q * (sum_k(start) - sum_k(end))
            answer2 = (answer2 + part) % MOD
            start = end
        if end == root:
            # All k in (root, N] have been covered.
            break

    return (answer1 + answer2) % MOD




def main():
    """Read "N M" from a single line of standard input and print the answer."""
    N, M = map(int, input().split())
    print(solve_true(N, M))


if __name__ == "__main__":
    main()