# Count unordered vertex pairs of an edge-coloured tree whose connecting path
# uses exactly two distinct colours.
# NOTE(review): that goal is inferred from the counting logic below (only
# paths from a centroid with one or two colours are ever tallied, and the
# pairings never combine into a third colour); confirm against the task.
#
# Phase 1 builds a centroid decomposition without recursion; Phase 2 visits
# every centroid C and counts the pairs whose path goes through C, using
# per-colour and per-colour-pair tallies of the paths that start at C.

# Input: "n K", then n-1 lines "a b c" = edge a-b with colour c. All three are
# shifted down by one, so they appear to be 1-indexed; colours are assumed to
# be in 1..K because uc/st3/st6 below are lists of length K.
n,K=map(int,input().split())
e=[[] for i in range(n)]  # e[x] = list of (neighbour, colour)
for _ in range(n-1):
    a,b,c=map(int,input().split())
    a-=1
    b-=1
    c-=1
    e[a]+=[(b,c)]
    e[b]+=[(a,c)]

# ---- Phase 1: centroid decomposition ----
# cdc: centroids in the order they were chosen.
# cdp[x]: index of x in cdc, or -1 while x has not been chosen yet
#         (a chosen vertex is treated as removed from the tree).
# v: 1 for vertices on the current iterative-DFS path, 0 otherwise.
# u: u[x] = size of x's subtree in the DFS rooted at the current `start`.
# cdq: start vertices of components still to split; it is appended to while
#      the for-loop iterates over it, which acts as a work queue.
cdc=[]
cdp=[-1]*n
v=[0]*n
u=[0]*n
cdq=[0]
for start in cdq:
    # Pass 1: subtree sizes. A vertex seen on top of the stack with v==0 is
    # entered (its unvisited, unremoved neighbours are pushed); seen again with
    # v==1, all of its children are finished, so its size is summed and it is
    # popped.
    q=[start]
    while len(q)>0:
        s=q[-1]
        if v[s]==0:
            v[s]=1
            q+=[t for t,_ in e[s] if v[t]==0 and cdp[t]==-1]
        else:
            # Here the neighbours with v==0 are exactly the finished children;
            # the parent is still on the path (v==1) and is excluded.
            u[s]=1+sum([u[t] for t,_ in e[s] if v[t]==0 and cdp[t]==-1])
            v[s]=0
            q.pop()
    # Pass 2: find a vertex all of whose remaining parts have size
    # <= u[start]//2. For neighbour t: a child (v[t]==0) gives part u[t]; the
    # parent (v[t]==1, on the path) gives everything outside s's subtree,
    # u[start]-u[s]. Already-removed neighbours (cdp[t]!=-1) are ignored, which
    # also masks their possibly stale u values. If several vertices qualify,
    # the last one visited is kept in `o`.
    q=[start]
    while len(q)>0:
        s=q[-1]
        if v[s]==0:
            if all((u[t] if v[t]==0 else u[start]-u[s])<=u[start]//2 or cdp[t]!=-1 for t,_ in e[s]):
                o=s
            v[s]=1
            q+=[t for t,_ in e[s] if v[t]==0 and cdp[t]==-1]
        else:
            v[s]=0
            q.pop()
    # Remove the centroid; each of its remaining neighbours starts one of the
    # new components.
    cdp[o]=len(cdc)
    cdc+=[o]
    cdq+=[t for t,_ in e[o] if cdp[t]==-1]

# ---- Phase 2: count pairs whose path goes through each centroid ----
# For the centroid C being processed, a "path" is the path from C to a vertex
# x reachable from C through vertices with cdp greater than cdp[C] (the
# vertices still present when C was chosen). A "branch" is the part of that
# component hanging off one neighbour `root` of C.
# st1[root][c1]: number of paths into root's branch using only colour c1.
# st2[root][c1*K+c2]: number of paths into root's branch using exactly the
#                     colours c1<c2.
# st3[c1]: number of single-colour paths of colour c1 over the branches
#          currently included in the totals.
# st4[c1*K+c2]: number of two-colour paths with colour pair (c1,c2), same scope.
# st5[C]: total number of single-colour paths from C, same scope.
# st6[c]: number of two-colour paths whose pair contains colour c, same scope.
# uc[c]: number of edges of colour c on the current DFS path from C.
# Every pair is added once from each endpoint, hence ans//2 at the end.
ans=0
v=[0]*n
st1=[{} for i in range(n)]
st2=[{} for i in range(n)]
st3=[0]*K
st4={}
st5=[0]*n
st6=[0]*K
uc=[0]*K
# Centroids are handled in reverse selection order. Each iteration starts a
# fresh st4, and its pass C returns st3/st6 to zero, so iterations do not
# interact.
for C in cdc[::-1]:
    st4={}
    u=set()  # distinct colours on the current DFS path from C (reuses name u)
    # Pass A: tally every path from C, per branch (st1/st2) and in total
    # (st3/st4/st5/st6).
    for root,c in e[C]:
        # Only neighbours chosen after C lie in C's component.
        if cdp[root]>cdp[C]:
            st1[root]={}
            st2[root]={}
            q=[(root,c)]  # stack of (vertex, colour of the edge into it)
            while len(q)>0:
                s,sc=q[-1]
                if v[s]==0:
                    # Entering s: extend the path C..s by the edge colour sc.
                    v[s]=1
                    uc[sc]+=1
                    if uc[sc]==1:
                        u.add(sc)
                    if len(u)==1:
                        c1=sorted(u)[0]
                        if c1 not in st1[root]:
                            st1[root][c1]=0
                        st1[root][c1]+=1
                        st3[c1]+=1
                        st5[C]+=1
                    if len(u)==2:
                        c1,c2=sorted(u)
                        c1c2=c1*K+c2  # one dict key for the colour pair c1<c2
                        if c1c2 not in st2[root]:
                            st2[root][c1c2]=0
                        st2[root][c1c2]+=1
                        if c1c2 not in st4:
                            st4[c1c2]=0
                        st4[c1c2]+=1
                        st6[c1]+=1
                        st6[c2]+=1
                    # NOTE(review): once len(u)>2 no descendant can contribute,
                    # yet the subtree is still walked (here and in passes B/C);
                    # pruning there would be a possible speed-up.
                    # The cdp test keeps the walk inside C's component and
                    # excludes C itself; the v test excludes the parent.
                    for t,tc in e[s]:
                        if v[t]==0 and cdp[t]>cdp[C]:
                            q+=[(t,tc)]
                else:
                    # Leaving s: undo its edge colour.
                    v[s]=0
                    uc[sc]-=1
                    if uc[sc]==0:
                        u.remove(sc)
                    q.pop()
    # Pairs (C, x) whose path C..x has exactly two colours; doubled to match
    # the two-sided counting used for all other pairs.
    ans+=sum(st4[c1c2] for c1c2 in st4)*2
    # Pass B: for each branch, temporarily subtract its own tallies so that
    # its vertices are paired only with vertices in other branches, walk it,
    # then add the tallies back.
    u=set()
    for root,c in e[C]:
        if cdp[root]>cdp[C]:
            for c1 in st1[root]:
                x=st1[root][c1]
                st3[c1]-=x
                st5[C]-=x
            for c1c2 in st2[root]:
                x=st2[root][c1c2]
                st4[c1c2]-=x
                st6[c1c2//K]-=x
                st6[c1c2%K]-=x
            q=[(root,c)]
            while len(q)>0:
                s,sc=q[-1]
                if v[s]==0:
                    v[s]=1
                    uc[sc]+=1
                    if uc[sc]==1:
                        u.add(sc)
                    if len(u)==1:
                        c1=sorted(u)[0]
                        # Pair with single-colour paths of any other colour.
                        ans+=st5[C]-st3[c1]
                        # Pair with two-colour paths containing c1; doubled
                        # because the two-colour endpoint below only counts
                        # other two-colour paths, not this pair.
                        ans+=st6[c1]*2
                    if len(u)==2:
                        c1,c2=sorted(u)
                        c1c2=c1*K+c2
                        # Pair with two-colour paths using the same colour
                        # pair. The key exists: pass A inserted it for s.
                        ans+=st4[c1c2]
                    for t,tc in e[s]:
                        if v[t]==0 and cdp[t]>cdp[C]:
                            q+=[(t,tc)]
                else:
                    v[s]=0
                    uc[sc]-=1
                    if uc[sc]==0:
                        u.remove(sc)
                    q.pop()
            for c1 in st1[root]:
                x=st1[root][c1]
                st3[c1]+=x
                st5[C]+=x
            for c1c2 in st2[root]:
                x=st2[root][c1c2]
                st4[c1c2]+=x
                st6[c1c2//K]+=x
                st6[c1c2%K]+=x
    # Pass C: walk the component again and undo pass A's increments of
    # st3/st5[C]/st6, so these shared counters are zero for the next centroid.
    u=set()
    for root,c in e[C]:
        if cdp[root]>cdp[C]:
            q=[(root,c)]
            while len(q)>0:
                s,sc=q[-1]
                if v[s]==0:
                    v[s]=1
                    uc[sc]+=1
                    if uc[sc]==1:
                        u.add(sc)
                    if len(u)==1:
                        c1=sorted(u)[0]
                        st3[c1]-=1
                        st5[C]-=1
                    if len(u)==2:
                        c1,c2=sorted(u)
                        c1c2=c1*K+c2  # computed but unused in this pass
                        st6[c1]-=1
                        st6[c2]-=1
                    for t,tc in e[s]:
                        if v[t]==0 and cdp[t]>cdp[C]:
                            q+=[(t,tc)]
                else:
                    v[s]=0
                    uc[sc]-=1
                    if uc[sc]==0:
                        u.remove(sc)
                    q.pop()
# Each qualifying pair was added twice (once from each endpoint).
print(ans//2)