def solve(n, l):
    """Return ``2 ** ceil(n / l) - 1`` computed exactly with integers.

    Args:
        n: Integer numerator of the ceiling division.
        l: Integer divisor; must be non-zero (``l == 0`` raises
            ``ZeroDivisionError``, same as the original script).

    Returns:
        ``2 ** k - 1`` where ``k = ceil(n / l)``.
    """
    # Exact integer ceiling division: floor-dividing the negation and
    # negating back rounds up, with no float precision loss for big ints.
    k = -(-n // l)
    return 2 ** k - 1


def main():
    """Read ``n`` and ``l`` from one line of stdin and print ``solve(n, l)``."""
    n, l = map(int, input().split())
    print(solve(n, l))


if __name__ == "__main__":
    main()