"""Count ordered vertex pairs in a weighted tree whose path has no zero-XOR prefix.

Input (stdin): ``n`` followed by ``n - 1`` edges ``a b c`` (1-indexed
endpoints, integer weight ``c``).  Output: the number of ordered pairs
``(s, t)`` with ``s != t`` such that no vertex ``q != s`` on the path from
``s`` to ``t`` has ``XOR(edge weights on s..q) == 0``.
"""
import sys
from collections import deque


def _root_tree(n, edges):
    """Root the tree at vertex 0.

    Returns ``(children, size, root_xor)``: ``children[v]`` lists the children
    of ``v``, ``size[v]`` is the number of vertices in ``v``'s subtree (``v``
    included) and ``root_xor[v]`` is the XOR of the edge weights on the path
    from vertex 0 to ``v``.
    """
    adj = [[] for _ in range(n)]
    for a, b, w in edges:
        adj[a].append((b, w))
        adj[b].append((a, w))

    parent = [-1] * n
    root_xor = [0] * n
    children = [[] for _ in range(n)]
    order = []  # every vertex is appended after its parent
    stack = deque([0])
    while stack:
        v = stack.pop()
        order.append(v)
        for u, w in adj[v]:
            if u != parent[v]:
                parent[u] = v
                root_xor[u] = root_xor[v] ^ w
                children[v].append(u)
                stack.append(u)

    size = [1] * n
    for v in reversed(order):  # children are finalized before their parent
        if parent[v] >= 0:
            size[parent[v]] += size[v]
    return children, size, root_xor


def solve(n, edges):
    """Return the number of ordered pairs (s, t), s != t, with no zero-XOR prefix.

    Args:
        n: number of vertices (0-indexed ``0 .. n-1``).
        edges: ``n - 1`` triples ``(a, b, w)`` with 0-indexed endpoints.

    A pair is "bad" when some vertex ``q != s`` on the path ``s -> t`` has
    ``XOR(s..q) == 0``, i.e. ``root_xor[q] == root_xor[s]``.  Each bad pair is
    charged to the first such ``q``, so the bad count is the sum, over pairs of
    equal-key vertices with no equal-key vertex strictly between them, of the
    sizes of the two "far sides".  The answer is ``n * (n - 1)`` minus that.
    """
    if n <= 1:
        return 0
    children, size, root_xor = _root_tree(n, edges)

    # Only equality of root_xor matters: relabel to dense ids for list indexing.
    ids = {}
    key = [ids.setdefault(val, len(ids)) for val in root_xor]
    m = len(ids)

    # State at the current DFS position, per id k:
    #   cnt[k] - number of id-k vertices visible from here (no id-k vertex
    #            between them and the current position)
    #   acc[k] - sum of the sizes of the parts of the tree lying beyond them
    cnt = [0] * m
    acc = [0] * m
    saved_cnt = [0] * n
    saved_acc = [0] * n
    next_child = [0] * n  # explicit per-vertex cursor; no bit-packing limits

    bad = 0
    stack = deque([0])
    while stack:
        v = stack[-1]
        k = key[v]
        i = next_child[v]
        if i == 0:
            # Entering v: each visible equal-key vertex q makes bad pairs
            # (v, t) for t beyond q (acc[k]) and (q, t) for t in v's subtree.
            bad += acc[k] + size[v] * cnt[k]
            saved_cnt[v] = cnt[k]
            saved_acc[v] = acc[k]
        kids = children[v]
        if i < len(kids):
            child = kids[i]
            next_child[v] = i + 1
            # Descending: v hides all other key-k vertices; v's far side is
            # everything outside child's subtree.
            cnt[k] = 1
            acc[k] = n - size[child]
            stack.append(child)
        else:
            # Leaving v: restore the outer state and publish v itself, which
            # hides every key-k vertex inside its own subtree.
            stack.pop()
            cnt[k] = saved_cnt[v] + 1
            acc[k] = saved_acc[v] + size[v]
    return n * (n - 1) - bad


def main():
    """Read the tree from stdin (1-indexed edges) and print the answer."""
    data = sys.stdin.buffer.read().split()
    n = int(data[0])
    nums = list(map(int, data[1:1 + 3 * max(n - 1, 0)]))
    edges = [
        (nums[i] - 1, nums[i + 1] - 1, nums[i + 2])
        for i in range(0, len(nums), 3)
    ]
    print(solve(n, edges))


if __name__ == "__main__":
    main()