# Counts tree paths by edge label using a centroid-style decomposition.
#
# NOTE(review): this file appears to have been flattened onto a single line,
# and several fragments look lost. The text between a '<' and the next '>'
# seems to have been stripped: "while g[s]cdp[nowc] ..." and
# "if g[s]cdp[nowc]: ..." are not valid Python. The code line below is kept
# byte-identical; restore the original from version control before relying
# on it.
#
# Input (stdin): the first line is "n K". The next n-1 lines are "a b c", an
# undirected edge between vertices a and b with label c. All three values are
# converted to 0-based. e[v] holds (neighbour, label) pairs.
# Presumably 1 <= c <= K, since cc, cq1 and cq2 are sized K — confirm.
#
# CD(e) builds the decomposition iteratively. It is intended as a centroid
# decomposition, judging by its name and the size-balance checks against
# w[start]. It returns:
#   c -- vertices in the order they were chosen as separators
#   p -- p[v] = index of v in c (-1 while v is unchosen)
# Each component is walked with an explicit stack q. In that walk:
#   d[] -- depth from the component root
#   u[] -- subtree size (1 + sizes of the not-yet-chosen children)
#   w[] -- w[root] is the component size recorded for that root (w[0] = n)
# NOTE(review): as flattened, `f=1` sits inside `for t,_ in e[s]`, so f seems
#   to reflect only one neighbour. Also, u[]/d[] are zeroed (the `for i in o`
#   loop) before w[t] is computed from u[t]/d[t]. Both look suspicious; confirm
#   against the original indentation. Even so, any chosen separator keeps the
#   counting correct; only the running time would suffer.
#
# Main loop: for each separator nowc (taken in reverse choice order), a DFS
# runs into every neighbour `start` with cdp[start] > cdp[nowc], i.e. a
# neighbour chosen later than nowc.
#   cc[x] / cs -- track the labels on the current nowc -> s path
#   p[s]       -- the label list for s, stored only when at most two distinct
#                 labels are present
#   sto[start] -- the vertices reached from `start`
#
# Counters, over all visited vertices of nowc's neighbouring subtrees:
#   cq1[x]      -- vertices whose label set is {x}
#   cqs         -- sum of cq1
#   cq2[x1][x2] -- vertices whose label list is [x1, x2]
#
# For each subtree, its own vertices are first removed from the counters. Each
# of its vertices is then paired against the remaining ones, and finally the
# subtree is added back:
#   a1 -- single-label vertex vs. single-label vertex with a different label;
#         each pair is seen from both ends, hence a1//2
#   a2 -- two-label vertex vs. single-label vertex whose label is x1 or x2
#   a3 -- two-label vertex vs. vertex with the same [x1, x2]; seen twice,
#         hence a3//2
#   a4 -- the path from nowc itself to a two-label vertex
# The final loop removes everything from the counters and resets v, p, g
# and sto.
#
# The printed total therefore appears to be the number of vertex pairs whose
# path uses exactly two distinct edge labels. This is inferred from the
# counters; confirm against the problem statement.
#
# NOTE(review): p[s]=list(cs).copy() keeps set iteration order, which is not
#   guaranteed to be sorted. For example, {0, 8} can iterate as [0, 8] or
#   [8, 0] depending on insertion history. The same label pair can then be
#   keyed as cq2[x1][x2] for one vertex and cq2[x2][x1] for another, so a3 may
#   undercount. sorted(cs) would give a canonical key.
# NOTE(review): for n == 1, CD never enters `for t,_ in e[s]`, so f / nc look
#   unbound before use — confirm.
n,K=map(int,input().split()) e=[[] for i in range(n)] for _ in range(n-1): a,b,c=map(int,input().split()) a-=1 b-=1 c-=1 e[a]+=[(b,c)] e[b]+=[(a,c)] def CD(e): c=[] p=[-1]*n d=[0]*n v=[0]*n u=[0]*n w=[0]*n w[0]=n sq=[0] for start in sq: o=[] q=[start] while len(q)>0: s=q[-1] if v[s]==0: v[s]=1 o+=[s] for t,_ in e[s]: if v[t]==0 and p[t]==-1: q+=[t] d[t]=d[s]+1 else: u[s]=1+sum(u[t] for t,_ in e[s] if v[t]==0 and p[t]==-1) for t,_ in e[s]: f=1 if p[t]==-1: if d[t]>d[s]: f&=w[start]-u[t]>=u[t] else: f&=u[s]>=w[start]-u[s] if f: nc=s v[s]=0 q.pop() p[nc]=len(c) c+=[nc] for i in o: u[i]=0 d[i]=0 for t,_ in e[nc]: if p[t]==-1: sq+=[t] if d[t]>d[nc]: w[t]=u[t] else: w[t]=w[start]-u[nc] return c,p cdc,cdp=CD(e) a1=0 a2=0 a3=0 a4=0 cq1=[0]*K cqs=0 cq2=[{} for i in range(K)] sto=[[] for i in range(n)] cc=[0]*K cs=set() v=[0]*n p=[[] for i in range(n)] g=[0]*n for nowc in cdc[::-1]: for start,rc in e[nowc]: if cdp[start]>cdp[nowc]: cc[rc]+=1 cs.add(rc) q=[(start,rc)] while len(q)>0: s,sc=q[-1] if v[s]==0: v[s]=1 sto[start]+=[s] if len(cs)<=2: p[s]=list(cs).copy() while g[s]cdp[nowc] and len(cs)<=3: break g[s]+=1 if g[s]cdp[nowc]: for i in sto[start]: if len(p[i])==1: x=p[i][0] cq1[x]+=1 cqs+=1 if len(p[i])==2: x1,x2=p[i] if x2 not in cq2[x1]: cq2[x1][x2]=0 cq2[x1][x2]+=1 for start,_ in e[nowc]: if cdp[start]>cdp[nowc]: for i in sto[start]: if len(p[i])==1: x=p[i][0] cq1[x]-=1 cqs-=1 if len(p[i])==2: x1,x2=p[i] cq2[x1][x2]-=1 for i in sto[start]: if len(p[i])==1: x=p[i][0] a1+=cqs-cq1[x] if len(p[i])==2: x1,x2=p[i] a2+=cq1[x1]+cq1[x2] a3+=cq2[x1][x2] a4+=1 for i in sto[start]: if len(p[i])==1: x=p[i][0] cq1[x]+=1 cqs+=1 if len(p[i])==2: x1,x2=p[i] cq2[x1][x2]+=1 for start,_ in e[nowc]: if cdp[start]>cdp[nowc]: for i in sto[start]: if len(p[i])==1: x=p[i][0] cq1[x]-=1 cqs-=1 if len(p[i])==2: x1,x2=p[i] cq2[x1][x2]-=1 v[i]=0 p[i]=[] g[i]=0 sto[start].clear() print(a1//2+a2+a3//2+a4)