def count_divisible_multiples(n, m):
    """Return how many positive multiples of *m* divide *n*.

    Counts the integers k >= 1 such that k*m <= n and n % (k*m) == 0.
    For n >= 1 this is 0 when m does not divide n, and otherwise equals the
    number of divisors of n // m (each divisor d of n // m gives the
    multiple d*m, which divides n).

    Args:
        n: The integer whose divisors are examined. Values <= 0 yield 0,
            because no positive multiple of m is <= n.
        m: The base of the multiples; must be >= 1.

    Returns:
        The number of multiples of m that divide n.

    Raises:
        ValueError: If m < 1 (the original scan divided by zero for m == 0
            and never terminated for most negative m).
    """
    if m < 1:
        raise ValueError(f"m must be a positive integer, got {m}")
    if n <= 0:
        return 0
    if n % m != 0:
        return 0

    quotient = n // m
    count = 0
    divisor = 1
    # Divisors come in pairs (d, quotient // d) with d <= sqrt(quotient);
    # scanning only up to sqrt makes this O(sqrt(n / m)) instead of O(n / m).
    while divisor * divisor <= quotient:
        if quotient % divisor == 0:
            # A perfect-square root pairs with itself, so count it once.
            count += 1 if divisor * divisor == quotient else 2
        divisor += 1
    return count


def main():
    """Read "n m" from one line of stdin and print the count."""
    n, m = map(int, input().split())
    print(count_divisible_multiples(n, m))


if __name__ == "__main__":
    main()