"""Count pairs of disjoint subset sums built from the clear bits of N.

Reads two integers ``N K`` from one line of standard input.  When
``N < K`` the script prints ``INF`` and stops.  Otherwise it prints the
number of ordered pairs ``(x, y)`` such that

* ``x`` and ``y`` are sums of two disjoint subsets of the powers of two
  ``2**i`` (``0 <= i < U``) whose bit is clear in ``N``, and
* ``0 <= x - y <= K``.
"""

# Number of low-order bit positions of N that are inspected.
U = 18


def count_pairs(n: int, k: int, width: int = U) -> int:
    """Return the number of ordered pairs ``(x, y)`` with ``0 <= x - y <= k``.

    ``x`` and ``y`` range over sums of disjoint subsets of the powers of two
    ``2**i`` (``0 <= i < width``) for which bit ``i`` of ``n`` is clear.
    The empty subset (sum 0) is allowed for either side.

    The enumeration visits every (set, disjoint submask) combination, i.e.
    ``3 ** z`` steps where ``z`` is the number of clear bits of ``n`` below
    ``width``.
    """
    # Powers of two at the clear bit positions of n, in ascending order.
    free_bits = [1 << i for i in range(width) if not (n >> i) & 1]
    count = len(free_bits)

    # sums[s] = sum of free_bits[i] over the bits i set in s.  Adding one
    # bit at a time doubles the filled prefix, so the table costs O(2**count).
    sums = [0] * (1 << count)
    for i, value in enumerate(free_bits):
        step = 1 << i
        for s in range(step):
            sums[s | step] = sums[s] + value

    full = (1 << count) - 1
    total = 0
    for chosen, x in enumerate(sums):
        # y may only use positions not already taken by x.
        rest = full ^ chosen
        sub = rest
        # Standard descending submask walk over `rest`, including 0.
        while True:
            if 0 <= x - sums[sub] <= k:
                total += 1
            if not sub:
                break
            sub = (sub - 1) & rest
    return total


def main() -> None:
    """Read ``N K`` from stdin and print ``INF`` or the pair count."""
    n, k = map(int, input().split())
    if n < k:
        # Returning here ends the script with status 0, the same as the
        # site-provided exit() but without depending on the site module.
        print("INF")
        return
    print(count_pairs(n, k))


if __name__ == "__main__":
    main()