# Read the two whitespace-separated integers n and k from one line of stdin.
n, k = (int(token) for token in input().split())

# Closed-form arithmetic series: for k >= 1, (first + last) * k / 2 is the sum
# of the k consecutive integers first..last. The product (first + last) * k is
# always even, so the floor division below is exact.
first = n - k + 1
last = n
series_total = (first + last) * k // 2

# The printed answer is the series total plus one.
r = series_total + 1
print(r)