def get():
    """Read one line from stdin and return its whitespace-separated integers.

    ``str.split()`` with no separator collapses runs of spaces/tabs and ignores
    leading/trailing whitespace, so e.g. ``"3  1"`` parses instead of raising
    ``ValueError`` on an empty field as ``split(' ')`` would.
    """
    return list(map(int, input().split()))


def counting(n, k):
    """Return the node count of the binary tree rooted at value ``k + 1``.

    The tree is defined as follows: every value ``x <= n`` is one node and has
    two children of value ``x + 1``; values ``> n`` are not nodes.  Level ``d``
    therefore holds ``2**d`` nodes of value ``k + 1 + d``, for
    ``d = 0 .. n - k - 1``.  The total is the geometric sum ``2**(n - k) - 1``
    when ``n > k``, and 0 otherwise.

    The closed form replaces an explicit-stack walk that visited every node,
    i.e. O(2**(n - k)) time; this version is O(1) in the number of steps.

    Args:
        n: inclusive upper bound on node values (assumed int).
        k: the root's value is ``k + 1`` (assumed int).

    Returns:
        The number of nodes, as a non-negative int.
    """
    if n <= k:
        # Root value k + 1 already exceeds n: the tree is empty.
        return 0
    return (1 << (n - k)) - 1


if __name__ == "__main__":
    N, K = get()
    count = counting(N, K)
    # NOTE(review): this K == 1 adjustment comes from the original script and
    # its rationale is not visible here (it also yields -1 when N <= 1);
    # confirm it against the problem statement.
    if K == 1:
        count -= 1
    print(count)