def count(n, k=None):
    """Return the number of positive divisors of *n* strictly greater than *k*.

    Divisors are enumerated in pairs ``(i, n // i)`` with ``i * i <= n``,
    so the loop runs O(sqrt(n)) times.

    Args:
        n: A positive integer whose divisors are counted.
        k: Exclusive lower bound on the divisors to count. If omitted, the
            module-level ``k`` is used, preserving the original ``count(n)``
            calling convention.

    Returns:
        The number of divisors ``d`` of *n* with ``d > k``.

    Raises:
        ValueError: If *n* is not positive (0 is divisible by every positive
            integer, and negative *n* is not supported).
    """
    if k is None:
        # Legacy behaviour: the original implementation read the global ``k``.
        k = globals()["k"]
    if n <= 0:
        raise ValueError("n must be a positive integer")
    total = 0
    i = 1
    # Exact integer bound; avoids float rounding (and the complex result
    # that ``n ** 0.5`` produces for negative n).
    while i * i <= n:
        if n % i == 0:
            j = n // i
            if i > k:
                total += 1
            # When i * i == n the pair collapses to a single divisor.
            if j != i and j > k:
                total += 1
        i += 1
    return total


if __name__ == "__main__":
    # Count x with n % x == k. For n > k these are exactly the divisors of
    # n - k that exceed k (a remainder is always smaller than its divisor).
    n, k = map(int, input().split())
    if n == k:
        # Every x > k leaves remainder k, so there are infinitely many x.
        # NOTE(review): "infinity" assumes the usual judge output token -- confirm.
        print("infinity")
    elif n < k:
        # Assuming non-negative input, n % x <= n < k, so no x works.
        print(0)
    else:
        print(count(n - k, k))