l,k = map(int,input().split()) t = (l - 1) // (2 * k) print(t * k)