def solve(n: int, k: int) -> int:
    """Return the sum of the k integers n-k+1 .. n, plus one.

    Args:
        n: Upper end of the run of consecutive integers.
        k: How many consecutive integers (ending at n) are summed.

    Returns:
        ((n - k + 1) + n) * k / 2 + 1, computed exactly with integers.
    """
    # Arithmetic-series sum: (first + last) * count / 2 with first = n-k+1,
    # last = n. (first + last) * k == (2n - k + 1) * k is always even
    # (one of the two factors is even), so floor division is exact.
    # NOTE(review): the trailing "+ 1" comes from the original script; its
    # meaning depends on the problem statement, which is not visible here.
    return (n - k + 1 + n) * k // 2 + 1


def main() -> None:
    """Read "n k" from one line of stdin and print solve(n, k)."""
    n, k = map(int, input().split())
    print(solve(n, k))


if __name__ == "__main__":
    main()