def _submasks(mask):
    """Yield every submask of the non-negative integer *mask*, from *mask* down to 0.

    Uses the standard ``sub = (sub - 1) & mask`` walk, so only the
    2**popcount(mask) actual submasks are visited.  *mask* must be >= 0,
    otherwise the walk never reaches 0.
    """
    sub = mask
    while True:
        yield sub
        if sub == 0:
            return
        sub = (sub - 1) & mask


def count_pairs(n, k):
    """Return the answer this script prints for the input pair (n, k).

    Special case ``n == 0``: returns ``None`` (printed as ``"INF"``) when
    ``k >= 1``, and ``1`` otherwise.

    For ``n != 0``, let ``mask`` be the bits that are clear in *n* below and
    including n's highest set bit, i.e. ``((1 << n.bit_length()) - 1) ^ n``.
    The result is the number of ordered pairs ``(a, c)`` such that:

    * ``a`` and ``c`` are both submasks of ``mask``,
    * ``a`` and ``c`` share no set bits, and
    * ``a <= c <= a + k``.

    Consequently a negative *k* (with ``n != 0``) yields 0.  A negative *n*
    also yields 0, because its ``mask`` is negative and has no submasks in
    the range the original loop scanned.
    """
    if n == 0:
        # The script reports an unbounded answer once k >= 1.
        return None if k >= 1 else 1
    if n < 0:
        # mask would be negative here; there are no candidates to count.
        return 0

    mask = ((1 << n.bit_length()) - 1) ^ n
    count = 0
    for a in _submasks(mask):
        # c must avoid every bit of a, so it lives in the rest of mask.
        remaining = mask ^ a
        # Any submask c of `remaining` satisfies c <= remaining, and we need
        # c >= a, so no candidate exists when remaining < a.
        max_d = remaining - a
        if max_d < 0:
            continue
        for d in range(min(k, max_d) + 1):
            c = a + d
            if (c & remaining) == c:
                count += 1
    return count


def main():
    """Read "N K" from stdin and print the count, or "INF" when unbounded."""
    n, k = map(int, input().split())
    answer = count_pairs(n, k)
    print("INF" if answer is None else answer)


if __name__ == "__main__":
    main()