l, k = map(int, input().split()) c = l-1 // (k*2) print(c * k)