def solve(x, n):
    """Compute the printed result for the input pair (x, n).

    For x >= 0 the value returned equals the final x of the process
        for i in 1..n: x = x - i if x >= 0 else x + i
    This equivalence is derived from the arithmetic below and checked by
    brute force. The original task statement is not visible here, so
    treat that description as a reading of the code, not a spec.

    Args:
        x: starting value (int).
        n: number of steps (int).

    Returns:
        The resulting int.
    """
    # k = min(n, smallest m >= 0 with m*(m+1)//2 > x).
    # Since only min(n, ...) is ever used, search in [0, n] with ok = n as
    # a sentinel. This replaces the previous fixed bound 2*10**9, which
    # silently produced a wrong k when x >= ~2*10**18 and n > 2*10**9.
    ng, ok = -1, n
    while ok - ng > 1:
        mid = (ok + ng) // 2
        if mid * (mid + 1) // 2 > x:
            ok = mid
        else:
            ng = mid
    k = ok

    # The first k steps all subtract (x stays >= 0 until step k, if k < n).
    x -= k * (k + 1) // 2

    # After that the steps alternate +i, -(i + 1); each pair nets -1.
    pairs = (n - k) // 2
    x -= pairs

    # An unpaired final step (n - k odd) is an addition of n.
    if (n - k) % 2:
        x += n
    return x


def main():
    """Read "x n" from one line of stdin and print solve(x, n)."""
    x, n = map(int, input().split())
    print(solve(x, n))


if __name__ == "__main__":
    main()