n, k = map(int, input().split())

n -= k
f = 1
ans = 0
while f * f <= n:
    if n % f == 0:
        if (n + k) % f == k:
            ans += 1
        if n // f != f:
            if (n + k) % (n // f) == k:
                ans += 1
    f += 1
print(ans)