import java.util.Scanner;

/**
 * Reads two integers {@code n} and {@code d} from standard input and prints the number of
 * quadruples {@code (a, b, c, e)}, every component in {@code [1, n]}, such that
 * {@code a*a + b*b == c*c + d - e*e}.
 */
class Main {
    public static void main(String[] args) {
        int n;
        int d;
        try (Scanner sc = new Scanner(System.in)) {
            n = sc.nextInt();
            d = sc.nextInt();
        }
        System.out.println(countQuadruples(n, d));
    }

    /**
     * Counts quadruples {@code (a, b, c, e)} in {@code [1, n]^4} with
     * {@code a*a + b*b == c*c + d - e*e}.
     *
     * <p>First builds a histogram of every {@code a*a + b*b}, then, for each pair {@code (c, e)},
     * adds the number of {@code (a, b)} pairs whose sum equals {@code c*c + d - e*e}. This is the
     * same total as matching two histograms, but needs only one array and no merge pass.
     *
     * @param n upper bound (inclusive) for every component; values below 1 yield 0
     * @param d the offset in the equation
     * @return the number of matching quadruples
     */
    private static long countQuadruples(int n, int d) {
        int maxSum = n * n * 2;
        // sumCount[s] = number of pairs (a, b) in [1, n]^2 with a*a + b*b == s.
        int[] sumCount = new int[maxSum + 1];
        for (int a = 1; a <= n; a++) {
            for (int b = 1; b <= n; b++) {
                sumCount[a * a + b * b]++;
            }
        }

        long total = 0;
        for (int c = 1; c <= n; c++) {
            for (int e = 1; e <= n; e++) {
                // Long arithmetic so a large |d| cannot wrap the target value.
                long target = (long) c * c + d - (long) e * e;
                // a*a + b*b always lies in [2, maxSum]; anything outside (0, maxSum] cannot match.
                if (target > 0 && target <= maxSum) {
                    total += sumCount[(int) target];
                }
            }
        }
        return total;
    }
}