def count_quadruples(n, d):
    """Count ordered (a, b, c, e) in [1, n]^4 with a**2 + b**2 == c**2 - e**2 + d.

    Builds a table of how many ordered pairs (a, b) reach each value of
    a**2 + b**2, then looks up every right-hand side c**2 - e**2 + d.

    Args:
        n: Upper bound (inclusive) for each of a, b, c, e. ``n <= 0`` yields 0.
        d: Integer offset on the right-hand side; may be any size or sign.

    Returns:
        The number of matching ordered quadruples.
    """
    squares = [k * k for k in range(1, n + 1)]
    # Largest possible a**2 + b**2; the table never needs to be bigger.
    max_sum = 2 * n * n
    # pair_count[s] == number of ordered pairs (a, b) with a**2 + b**2 == s.
    pair_count = [0] * (max_sum + 1)
    for x in squares:
        for y in squares:
            pair_count[x + y] += 1

    total = 0
    for c_sq in squares:
        for e_sq in squares:
            target = c_sq - e_sq + d
            if target < 1:
                # e_sq only grows from here, so target only shrinks:
                # no later e can produce a positive target.
                break
            # Targets above max_sum are unreachable; skip (not break), since
            # larger e_sq values may bring target back into range.
            if target <= max_sum:
                total += pair_count[target]
    return total


def main():
    """Read "N D" from stdin and print the number of matching quadruples."""
    n, d = map(int, input().split())
    print(count_quadruples(n, d))


if __name__ == "__main__":
    main()