def solve():
    """Count quadruples of squares satisfying x^2 - y^2 + d = z^2 + w^2.

    Reads two integers ``n`` and ``d`` from a single line of standard input
    and prints the number of ordered quadruples ``(x, y, z, w)``, each in
    ``1..n``, for which ``x*x - y*y + d == z*z + w*w``.

    The left-hand side is tallied for every pair ``(x, y)``; the tally is
    then looked up at every right-hand-side value ``z*z + w*w``.
    """
    n, d = map(int, input().split())
    squares = [k * k for k in range(1, n + 1)]

    # A sum of two squares never exceeds 2*n*n, so left-hand values above
    # that (or below zero) can never be matched and need not be stored.
    # Sizing the table from n instead of a fixed constant keeps every index
    # in range for any n and d.
    limit = 2 * n * n
    cnt = [0] * (limit + 1)
    for a in squares:
        base = a + d  # invariant across the inner loop
        for b in squares:
            s = base - b
            if 0 <= s <= limit:
                cnt[s] += 1

    ans = 0
    for a in squares:
        for b in squares:
            ans += cnt[a + b]
    print(ans)


if __name__ == "__main__":
    solve()