n,k=map(int,input().split()) n-=k s=set() for i in range(2,n**0.5+2): if n%i==0: if i>k: s.add(i) if n//i>k: s.add(n//i) print(len(s))