import sys
from collections import deque, defaultdict
from itertools import groupby
from random import randint

# Modulus applied to both components of every hash pair.
M = 10 ** 18
# Memo table used by f(): key pair -> random pair assigned to that key.
pool = {}
# Not referenced anywhere in this file; kept so the module's names stay intact.
lcg_gen = None


def f(e):
    """Return the random hash pair assigned to key ``e``.

    The first time a key is seen, a pair of integers drawn with
    ``randint(0, M - 1)`` is stored in ``pool``; later calls with an equal
    key return that same stored pair.
    """
    if e not in pool:
        pool[e] = (randint(0, M - 1), randint(0, M - 1))
    return pool[e]


def _add(a, b):
    """Add two hash pairs component-wise modulo ``M``."""
    return ((a[0] + b[0]) % M, (a[1] + b[1]) % M)


def _sub(a, b):
    """Subtract hash pair ``b`` from ``a`` component-wise modulo ``M``."""
    return ((a[0] - b[0]) % M, (a[1] - b[1]) % M)


def solve(N, E):
    """Sum the "collapsed size" of the tree over its distinct rootings.

    Every rooted subtree gets a hash ``f(sum of its children's hashes)`` and
    a size ``1 + sum of sizes of its *distinct* children`` (children with an
    equal ``(hash, size)`` pair are counted once).  Both values are computed
    for the tree rooted at node 0 and then re-rooted at every other node.
    Rootings are keyed by the sum of the hashes hanging off the root; the
    result is the sum of the sizes stored under the distinct keys.

    NOTE(review): the original problem statement is not part of this file;
    the description above is derived from the code alone.

    Args:
        N: number of nodes, labelled ``0 .. N - 1``.
        E: adjacency list; ``E[i]`` lists the neighbours of node ``i``.
           The traversal assumes the graph is a tree.

    Returns:
        The integer sum described above; ``0`` when ``N == 0``.
    """
    if N == 0:
        # The original indexed par[0] unconditionally and raised IndexError.
        return 0

    # --- BFS from node 0: parent pointers, visiting order, child lists ---
    par = [-1] * N
    par[0] = -2  # sentinel marking the root; never equal to a node index
    order = []
    children = [[] for _ in range(N)]
    queue = deque([0])
    while queue:
        i = queue.popleft()
        order.append(i)
        kids = [j for j in E[i] if j != par[i]]
        for j in kids:
            par[j] = i
        queue.extend(kids)
        children[i] = kids

    # --- Bottom-up pass: hash and collapsed size of each subtree ---
    up = [(0, 0)] * N
    up2 = [0] * N
    for i in reversed(order):
        kids = children[i]
        acc = (0, 0)
        for j in kids:
            acc = _add(acc, up[j])
        up[i] = f(acc)
        # A set keeps one representative of each distinct (hash, size) child.
        distinct = {(up[j], up2[j]) for j in kids}
        up2[i] = sum(size for _, size in distinct) + 1

    # --- Top-down pass: the same two values for the part "above" a node ---
    # down[j] / down2[j] describe the tree rooted at par[j] with j's own
    # subtree removed.  The root keeps the initial (0, 0) / 0.
    down = [(0, 0)] * N
    down2 = [0] * N
    dp2 = [0] * N          # collapsed size of the whole tree rooted at i
    signature = [(0, 0)] * N  # sum of all hashes hanging off i as root
    for i in order:
        kids = children[i]
        is_root = par[i] == -2

        # Multiplicity of every (hash, size) pair among i's neighbours.
        count = defaultdict(int)
        if not is_root:
            count[(down[i], down2[i])] += 1
        for j in kids:
            count[(up[j], up2[j])] += 1

        total_size = sum(size for _, size in count) + 1
        dp2[i] = total_size

        # down[0] is (0, 0), so the root contributes nothing from above.
        total_hash = down[i]
        for j in kids:
            total_hash = _add(total_hash, up[j])
        signature[i] = total_hash

        for j in kids:
            # Everything around i except j's subtree: total minus own share.
            down[j] = f(_sub(total_hash, up[j]))
            if count[(up[j], up2[j])] == 1:
                down2[j] = total_size - up2[j]
            else:
                # An identical neighbour remains, so the size is unchanged.
                down2[j] = total_size

    # One entry per distinct root signature; on equal keys the later index
    # overwrites, matching iteration over range(N).
    final_set = {signature[i]: dp2[i] for i in range(N)}
    return sum(final_set.values())


def main():
    """Read ``N`` and ``N - 1`` one-based edges from stdin; print the answer."""
    N = int(input())
    E = [[] for _ in range(N)]
    for _ in range(N - 1):
        u, v = map(int, input().split())
        E[u - 1].append(v - 1)
        E[v - 1].append(u - 1)
    print(solve(N, E))


if __name__ == "__main__":
    main()