"""Count DP paths over the bits missing from ``n``, with differences capped at ``k``.

Reads two integers ``n`` and ``k`` from one line of standard input and
prints either ``INF`` or the count computed by ``solve``.
"""

# Only bit positions 0..BIT_WIDTH-1 of ``n`` are inspected.
# NOTE(review): this assumes 0 <= n, k < 2**31 -- confirm against the
# problem constraints, which are not visible here.
BIT_WIDTH = 31


def solve(n, k):
    """Return the text to print for the input pair ``(n, k)``.

    Returns ``"INF"`` when ``n == 0`` and ``k >= 1``.

    Otherwise, consider every bit position below ``BIT_WIDTH`` that is
    clear in ``n`` and whose value ``1 << bit`` is at most ``k``.  These
    bits are processed from highest to lowest.  The walk starts with a
    single path at difference 0.  For each bit, a path may:

    * subtract the bit's value, kept only if the difference stays >= 0;
    * add the bit's value, kept only if the difference stays <= k;
    * skip the bit.

    The result is the number of surviving paths, as a decimal string.
    Clear bits whose value exceeds ``k`` are never used.

    Raises ``IndexError`` when ``k < 0``, because the DP table would be
    empty.
    """
    if n == 0 and k >= 1:
        # NOTE(review): presumably the count is unbounded in this case;
        # the problem statement is not visible here.
        return "INF"

    # Clear bits of n whose value is <= k, highest first.  Order matters:
    # out-of-range differences are dropped at every step, so a different
    # processing order could give a different total.
    low_bits = [
        bit
        for bit in reversed(range(BIT_WIDTH))
        if not (n & (1 << bit)) and (1 << bit) <= k
    ]

    # dp[d] is the number of ways to reach difference d (0 <= d <= k).
    dp = [0] * (k + 1)
    dp[0] = 1
    for bit in low_bits:
        val = 1 << bit
        new_dp = [0] * (k + 1)
        for d, ways in enumerate(dp):
            if not ways:
                continue
            # Assign to a: subtract val.  d - val <= d <= k always holds,
            # so only the lower bound needs checking.
            if d >= val:
                new_dp[d - val] += ways
            # Assign to b: add val, if it stays within the cap.
            if d + val <= k:
                new_dp[d + val] += ways
            # Assign to neither: difference unchanged.
            new_dp[d] += ways
        dp = new_dp
    return str(sum(dp))


def main():
    """Read ``n k`` from one line of standard input and print the answer."""
    n, k = map(int, input().split())
    print(solve(n, k))


if __name__ == "__main__":
    main()