n, k = map(int, input().split())
print(k * (2 * n - k + 1) // 2 + 1)  # n + (n-1) + ... + (n-k+1), plus 1; product is always even, so // is exact