import sys


def solve(n: int, m: int) -> int:
    """Return the value the original script printed for inputs ``n`` and ``m``.

    Two regimes:

    * If ``n - 1 <= floor(log2(m))`` (for ``n >= 1`` and ``m >= 1`` this is
      ``2 ** (n - 1) <= m``), the result is ``2 ** n - 1``, i.e.
      ``1 + 2 + ... + 2 ** (n - 1)``.
    * Otherwise the result is ``m + ceil(m / 2) + ceil(m / 4) + ...``, summing
      every term of that halving chain that is greater than 1. Then 1 is added
      for each of the remaining ``n - steps`` positions, where ``steps`` is the
      number of terms summed.

    NOTE(review): the underlying problem statement is not visible here. The
    description above states what the arithmetic does, not why.
    """
    # floor(log2(m)) for m >= 1; -1 when m == 0.
    top_bit = m.bit_length() - 1
    if n - 1 <= top_bit:
        return 2 ** n - 1

    total = 0
    steps = 0
    while m > 1:
        total += m
        m = (m + 1) >> 1  # ceil(m / 2)
        steps += 1
    # Every position not covered by the halving chain contributes exactly 1.
    return total + (n - steps)


def main() -> None:
    """Read ``N M`` from one line of stdin and print ``solve(N, M)``."""
    # Local alias instead of shadowing the builtin ``input``.
    read_line = sys.stdin.readline
    n, m = map(int, read_line().split())
    print(solve(n, m))


if __name__ == "__main__":
    main()