l, k = map(int, input().split()) c = (l-1) // (k*2) print(c * k)