import java.util.*;

/**
 * Reads two integers {@code n} and {@code d} from standard input and prints how many
 * quadruples {@code (a, b, c, e)}, each component in {@code 1..n}, satisfy
 * {@code a*a + b*b == c*c - e*e + d}.
 */
public class Main {

    /**
     * Program entry point: parses {@code n} then {@code d} from standard input and
     * writes the resulting count on its own line to standard output.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        Scanner in = new Scanner(System.in);
        int n = in.nextInt();
        int d = in.nextInt();
        System.out.println(countMatches(n, d));
    }

    /**
     * Counts matching quadruples by bucketing both sides of the equation by value.
     *
     * @param n upper bound (inclusive) for every component
     * @param d constant offset added to the difference of squares
     * @return the sum, over every value in {@code 1..2*n*n}, of (pairs whose squares sum
     *         to that value) times (pairs whose shifted square difference equals it)
     */
    private static long countMatches(int n, int d) {
        // Largest value a sum of two squares can reach with both terms in 1..n.
        int limit = n * n * 2;
        int[] bySquareSum = new int[limit + 1];
        int[] byShiftedDiff = new int[limit + 1];

        for (int a = 1; a <= n; a++) {
            int aSquared = a * a;
            for (int b = 1; b <= n; b++) {
                int bSquared = b * b;
                bySquareSum[aSquared + bSquared]++;

                // Only values inside 1..limit can ever pair up with a sum of two squares.
                int shifted = aSquared - bSquared + d;
                if (shifted >= 1 && shifted <= limit) {
                    byShiftedDiff[shifted]++;
                }
            }
        }

        long total = 0;
        int value = 1;
        while (value <= limit) {
            // The product is formed in int and then widened, exactly as before.
            // NOTE(review): assumes per-value tallies stay small enough that this int
            // product cannot overflow - confirm against the expected input limits.
            total += bySquareSum[value] * byShiftedDiff[value];
            value++;
        }
        return total;
    }
}