import itertools


def getlist():
    """Read one line from stdin and return its whitespace-separated ints."""
    return list(map(int, input().split()))


def solve(n, k):
    """Return "INF", or the count of disjoint zero-bit subset pairs with B - A <= k.

    The zero bits considered are the bits of ``n`` that are 0 at positions
    below ``n.bit_length()``. A pair (A, B) of disjoint subsets of those bits
    is counted when ``sum(B) - sum(A) <= k``.

    Args:
        n: Non-negative integer whose zero bits are enumerated.
        k: Upper bound on ``sum(B) - sum(A)``.

    Returns:
        The string ``"INF"`` when ``2 ** n.bit_length() <= k``; otherwise
        the number of qualifying (A, B) pairs as an int.

    NOTE(review): the problem statement is not part of this file. The
    original nested loops were broken: the B loop re-read the outer ``i``/``j``
    instead of ``v``/``w``, and the accumulators were reset per bit. This
    implements the pairing that the loop structure describes. Negative
    differences always satisfy ``<= k``; confirm whether a lower bound (e.g.
    ``0 <= B - A``) was intended. Runtime is 3 ** len(zero_bits); confirm it
    fits the input constraints.
    """
    if 2 ** n.bit_length() <= k:
        return "INF"
    # Powers of two at positions where n has a 0 bit, below its top set bit.
    zero_bits = [1 << i for i in range(n.bit_length()) if not (n >> i) & 1]
    # A disjoint pair (A, B) is equivalent to assigning each zero bit to A
    # (contributes -bit), to B (+bit), or to neither (0). B - A is then the
    # signed sum. This enumerates 3**m choices instead of 4**m mask pairs
    # filtered by A & B == 0.
    return sum(
        1
        for signs in itertools.product((-1, 0, 1), repeat=len(zero_bits))
        if sum(s * b for s, b in zip(signs, zero_bits)) <= k
    )


if __name__ == "__main__":
    N, K = getlist()
    print(solve(N, K))