import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.List;
import java.util.Scanner;

/**
 * Reads a tree rooted at node 0 and prints how many (ancestor, descendant) pairs have an
 * ancestor index smaller than the descendant index.
 *
 * <p>Input format: an integer {@code n}, followed by {@code n - 1} integers giving the parent of
 * nodes {@code 1 .. n-1}. Only nodes reachable from node 0 contribute to the count.
 */
public class Main {

    /** Reads the tree from standard input and prints the ascending ancestor/descendant pair count. */
    public static void main(String[] args) {
        try (Scanner sc = new Scanner(System.in)) {
            int n = sc.nextInt();
            int[] parent = new int[Math.max(n, 0)];
            for (int i = 1; i < n; i++) {
                parent[i] = sc.nextInt();
            }
            System.out.println(countAscendingPairs(parent));
        }
    }

    /**
     * Counts pairs {@code (a, d)} where {@code a} is a proper ancestor of {@code d} in the tree
     * rooted at node 0 and {@code a < d}.
     *
     * @param parent {@code parent[i]} is the parent of node {@code i} for
     *     {@code 1 <= i < parent.length}; {@code parent[0]} is ignored (node 0 is the root)
     * @return the number of such pairs; returned as {@code long} because it can reach
     *     {@code n(n-1)/2}, which overflows {@code int} for large trees
     * @throws IllegalArgumentException if some parent index lies outside {@code [0, parent.length)}
     */
    static long countAscendingPairs(int[] parent) {
        int n = parent.length;
        if (n == 0) {
            return 0;
        }

        List<List<Integer>> children = new ArrayList<>(n);
        for (int i = 0; i < n; i++) {
            children.add(new ArrayList<>());
        }
        for (int i = 1; i < n; i++) {
            int p = parent[i];
            if (p < 0 || p >= n) {
                throw new IllegalArgumentException("parent of node " + i + " out of range: " + p);
            }
            children.get(p).add(i);
        }

        // Fenwick tree over node indices; a 1 marks a node on the current root-to-node path.
        int[] onPath = new int[n + 1];
        long total = 0;

        // Iterative DFS: recursion would overflow the call stack on a path-shaped tree.
        // A non-negative entry v means "enter v"; a negative entry ~v means "leave v".
        // Node 0 never appears in a children list, so every node is entered at most once.
        Deque<Integer> stack = new ArrayDeque<>();
        stack.push(0);
        while (!stack.isEmpty()) {
            int entry = stack.pop();
            if (entry < 0) {
                fenwickAdd(onPath, ~entry, -1);
                continue;
            }
            // Ancestors currently on the path whose index is smaller than this node.
            total += fenwickCountBelow(onPath, entry);
            fenwickAdd(onPath, entry, 1);
            stack.push(~entry);
            for (int child : children.get(entry)) {
                stack.push(child);
            }
        }
        return total;
    }

    /** Adds {@code delta} at node index {@code node} (stored at 1-based position node + 1). */
    private static void fenwickAdd(int[] tree, int node, int delta) {
        for (int i = node + 1; i < tree.length; i += i & -i) {
            tree[i] += delta;
        }
    }

    /** Returns the sum over node indices {@code 0 .. node-1}. */
    private static int fenwickCountBelow(int[] tree, int node) {
        int sum = 0;
        for (int i = node; i > 0; i -= i & -i) {
            sum += tree[i];
        }
        return sum;
    }
}