l, k = map(int, input().split()) if l % (2 * k) != 0: print(l // (k * 2) * k) else: print((l // (k * 2) - 1) * k)