def _best(n, k):
    """Return the answer for the pair (n, k).

    For odd n the upper bound is n itself; for even n it is n // 2.
    In both cases the result is additionally capped at k + 1.
    """
    upper = n if n % 2 else n // 2
    return min(upper, k + 1)


# Input: a single line with two whitespace-separated integers n and k.
n, k = map(int, input().split())
print(_best(n, k))