import sys
from collections import Counter


def input():
    """Return the next line of stdin with its trailing newline removed.

    Deliberately shadows the builtin ``input`` so the rest of this script
    reads directly via ``sys.stdin.readline`` (presumably for speed on large
    inputs).
    """
    return sys.stdin.readline().rstrip('\n')


def _count_matches(n, offset):
    """Count index quadruples (a, b, c, e), each in 1..n, with
    a*a + b*b == c*c - e*e + offset.

    Every value of ``a*a + b*b`` is tallied once in a Counter.  Each
    ``c*c - e*e + offset`` then contributes the number of sums equal to it.
    That is the same total the original sort-and-merge produced (product of
    multiplicities per shared value), in O(n^2) instead of O(n^2 log n).
    """
    squares = [k * k for k in range(1, n + 1)]
    sum_counts = Counter(x + y for x in squares for y in squares)
    # Counter returns 0 for missing keys without inserting them.
    return sum(sum_counts[x - y + offset] for x in squares for y in squares)


def main():
    """Read ``n d`` from stdin and print the number of quadruples
    (a, b, c, e) in 1..n with a*a + b*b == c*c - e*e + d.
    """
    n, offset = map(int, input().split())
    print(_count_matches(n, offset))


if __name__ == '__main__':
    main()