def count_pairs(n, k):
    """Count pairs (x, y) with x <= y <= x + k and x & y == n.

    Args:
        n: target value of ``x & y``.
        k: maximum allowed difference ``y - x``.

    Returns:
        The string ``'INF'`` when ``n < k`` (infinitely many pairs exist),
        otherwise the finite number of pairs as an ``int``.
    """
    if n < k:
        # For every m with 2**m > n, x = 2**m - 1 and y = 2**m + n give
        # x & y == n and y - x == n + 1 <= k, so the pairs never run out.
        return 'INF'
    # When n >= k, every solution has x below 2**n.bit_length(), and the
    # inner range covers every admissible y for that x.
    limit = 2 ** n.bit_length()
    total = 0
    for x in range(limit):
        if x & n != n:
            # x & y == n forces x to contain every set bit of n; skip the
            # whole inner loop otherwise.
            continue
        for y in range(x, x + k + 1):
            if x & y == n:
                total += 1
    return total


def main():
    """Read "N K" from stdin and print the pair count (or INF)."""
    n, k = map(int, input().split())
    print(count_pairs(n, k))


if __name__ == "__main__":
    main()