def largest_multiple_below_half(limit: int, k: int) -> int:
    """Return the largest multiple ``t * k`` (t >= 1) with ``2 * t * k < limit``, else 0.

    This is the closed form of the original linear search. That search
    started at ``x = 1`` and incremented ``x`` while ``limit - 2*x*k > 0``,
    then returned ``(x - 1) * k``. The closed form gives the same result
    in O(1) instead of O(limit / k) iterations.

    Args:
        limit: Strict upper bound on ``2 * result`` (any integer).
        k: Step whose multiples are considered.

    Returns:
        ``k * ((limit - 1) // (2 * k))`` when ``limit > 2 * k``, otherwise 0.

    Raises:
        ValueError: If ``k <= 0`` while ``limit > 2 * k``. The original
            search never terminates for these inputs, because
            ``limit - 2*x*k`` never drops to 0 or below.
    """
    # The loop body never runs when the first test (x = 1) already fails,
    # so the answer is (1 - 1) * k == 0 for any sign of k.
    if limit <= 2 * k:
        return 0
    if k <= 0:
        raise ValueError(
            "k must be positive when limit > 2*k; the search would never terminate"
        )
    # Largest t with 2*t*k < limit, i.e. 2*t*k <= limit - 1. Here both
    # operands are positive, so floor division is exact, and t >= 1
    # because limit - 1 >= 2*k.
    return k * ((limit - 1) // (2 * k))


def main() -> None:
    """Read two integers from one stdin line and print the computed value."""
    limit, k = map(int, input().split())
    print(largest_multiple_below_half(limit, k))


if __name__ == "__main__":
    main()