# NOTE(review): this script appears to have been damaged before it reached
# this file and is not runnable as-is:
#   * every newline and all indentation have been collapsed onto one line;
#   * text seems to be missing in at least three places, where a comparison
#     is cut off and resumes mid-expression at `cdp[nowc]:` -- after
#     `u[start]-u[t]`, after `while g[s]`, and after `if g[s]`.  Presumably
#     everything between a `<` and a later `>` was stripped (e.g. by an HTML
#     sanitizer) -- TODO confirm and restore from the original source.
#   The missing spans include the remainder of `CD`, any call to it, and the
#   definitions of `cdp`, `nowc`, `cc`, `cs`, `rc`, `sto`, `g`, `cq1`, `cq2`,
#   `cqs` and `a1`..`a4`, none of which are visible here.
#
# What the visible code does:
#   * Reads `n` and `K` from the first input line (`K` is not referenced in
#     the visible remainder), then reads n-1 lines `a b c`, each decremented
#     by 1, and stores every one as an undirected edge in the adjacency list
#     `e` as (neighbour, c) pairs.
#   * `CD(e)` starts an explicit-stack depth-first traversal from each
#     `start` in `sq`.  The first pass records visit order in `o` and sets
#     `u[s]` to 1 plus the sum of `u[t]` over neighbours `t` that have
#     `v[t]==0` and `p[t]==-1`; the second pass begins comparing
#     `u[start]-u[t]` against something that is lost.  Presumably this is a
#     centroid decomposition with `u` as subtree sizes -- verify against the
#     original.
#   * The tail walks the neighbours `start` of `nowc` with
#     `cdp[start]>cdp[nowc]` and the nodes `i` stored in `sto[start]`.  Nodes
#     whose `p[i]` has one element are tallied in `cq1[x]` and `cqs`; nodes
#     whose `p[i]` has two elements are tallied in `cq2[x1][x2]`.  Each
#     group's own tallies are removed before it is scored and re-added
#     afterwards, then `v`, `p`, `g` and `sto[start]` are reset.
#   * Scoring adds `cqs-cq1[x]` to `a1` for one-element nodes, and
#     `cq1[x1]+cq1[x2]` to `a2`, `cq2[x1][x2]` to `a3` and 1 to `a4` for
#     two-element nodes.  The script prints `a1//2+a2+a3//2+a4`.
n,K=map(int,input().split()) e=[[] for i in range(n)] for _ in range(n-1): a,b,c=map(int,input().split()) a-=1 b-=1 c-=1 e[a]+=[(b,c)] e[b]+=[(a,c)] def CD(e): c=[] p=[-1]*n d=[0]*n v=[0]*n u=[0]*n sq=[0] for start in sq: o=[] q=[start] while len(q)>0: s=q[-1] if v[s]==0: v[s]=1 o+=[s] q+=[t for t,_ in e[s] if v[t]==0 and p[t]==-1] else: u[s]=1+sum(u[t] for t,_ in e[s] if v[t]==0 and p[t]==-1) v[s]=0 q.pop() q=[start] while len(q)>0: s=q[-1] if v[s]==0: v[s]=1 for t,_ in e[s]: if v[t]==0 and p[t]==-1: q+=[t] if u[start]-u[t]cdp[nowc]: cc[rc]+=1 cs.add(rc) q=[(start,rc)] while len(q)>0: s,sc=q[-1] if v[s]==0: v[s]=1 sto[start]+=[s] if len(cs)<=2: p[s]=list(cs).copy() while g[s]cdp[nowc]: break g[s]+=1 if g[s]cdp[nowc]: for i in sto[start]: if len(p[i])==1: x=p[i][0] cq1[x]+=1 cqs+=1 if len(p[i])==2: x1,x2=p[i] if x2 not in cq2[x1]: cq2[x1][x2]=0 cq2[x1][x2]+=1 for start,_ in e[nowc]: if cdp[start]>cdp[nowc]: for i in sto[start]: if len(p[i])==1: x=p[i][0] cq1[x]-=1 cqs-=1 if len(p[i])==2: x1,x2=p[i] cq2[x1][x2]-=1 for i in sto[start]: if len(p[i])==1: x=p[i][0] a1+=cqs-cq1[x] if len(p[i])==2: x1,x2=p[i] a2+=cq1[x1]+cq1[x2] a3+=cq2[x1][x2] a4+=1 for i in sto[start]: if len(p[i])==1: x=p[i][0] cq1[x]+=1 cqs+=1 if len(p[i])==2: x1,x2=p[i] cq2[x1][x2]+=1 for start,_ in e[nowc]: if cdp[start]>cdp[nowc]: for i in sto[start]: if len(p[i])==1: x=p[i][0] cq1[x]-=1 cqs-=1 if len(p[i])==2: x1,x2=p[i] cq2[x1][x2]-=1 for start,_ in e[nowc]: if cdp[start]>cdp[nowc]: for i in sto[start]: v[i]=0 p[i]=[] g[i]=0 sto[start].clear() print(a1//2+a2+a3//2+a4)