結果
問題 | No.1002 Twotone |
ユーザー |
![]() |
提出日時 | 2025-01-24 11:32:59 |
言語 | PyPy3 (7.3.15) |
結果 |
TLE
|
実行時間 | - |
コード長 | 3,338 bytes |
コンパイル時間 | 1,413 ms |
コンパイル使用メモリ | 82,048 KB |
実行使用メモリ | 419,852 KB |
最終ジャッジ日時 | 2025-01-24 11:35:10 |
合計ジャッジ時間 | 121,941 ms |
ジャッジサーバーID (参考情報) |
judge2 / judge4 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | AC * 2 |
other | AC * 15 TLE * 18 |
ソースコード
n,K=map(int,input().split()) e=[[] for i in range(n)] for _ in range(n-1): a,b,c=map(int,input().split()) a-=1 b-=1 c-=1 e[a]+=[(b,c)] e[b]+=[(a,c)] cdc=[] cdp=[-1]*n v=[0]*n u=[0]*n cdq=[0] for start in cdq: q=[start] while len(q)>0: s=q[-1] if v[s]==0: v[s]=1 q+=[t for t,_ in e[s] if v[t]==0 and cdp[t]==-1] else: u[s]=1+sum([u[t] for t,_ in e[s] if v[t]==0 and cdp[t]==-1]) v[s]=0 q.pop() q=[start] while len(q)>0: s=q[-1] if v[s]==0: if all((u[t] if v[t]==0 else u[start]-u[s])<=u[start]//2 or cdp[t]!=-1 for t,_ in e[s]): o=s v[s]=1 q+=[t for t,_ in e[s] if v[t]==0 and cdp[t]==-1] else: v[s]=0 q.pop() cdp[o]=len(cdc) cdc+=[o] cdq+=[t for t,_ in e[o] if cdp[t]==-1] ans=0 v=[0]*n st1=[{} for i in range(n)] st2=[{} for i in range(n)] st3=[{} for i in range(n)] st4=[{} for i in range(n)] st5=[0]*n st6=[{} for i in range(n)] uc=[0]*K for C in cdc[::-1]: u=set() for root,c in e[C]: if cdp[root]>cdp[C]: st1[root]={} st2[root]={} q=[(root,c)] while len(q)>0: s,sc=q[-1] if v[s]==0: v[s]=1 uc[sc]+=1 if uc[sc]==1: u.add(sc) if len(u)==1: c1=sorted(u)[0] if c1 not in st1[root]: st1[root][c1]=0 st1[root][c1]+=1 if c1 not in st3[C]: st3[C][c1]=0 st3[C][c1]+=1 st5[C]+=1 if len(u)==2: c1,c2=sorted(u) c1c2=c1*K+c2 if c1c2 not in st2[root]: st2[root][c1c2]=0 st2[root][c1c2]+=1 if c1c2 not in st4[C]: st4[C][c1c2]=0 st4[C][c1c2]+=1 if c1 not in st6[C]: st6[C][c1]=0 st6[C][c1]+=1 if c2 not in st6[C]: st6[C][c2]=0 st6[C][c2]+=1 for t,tc in e[s]: if v[t]==0 and cdp[t]>cdp[C]: q+=[(t,tc)] else: v[s]=0 uc[sc]-=1 if uc[sc]==0: u.remove(sc) q.pop() ans+=sum(st4[C][c1c2] for c1c2 in st4[C])*2 u=set() for root,c in e[C]: if cdp[root]>cdp[C]: for c1 in st1[root]: x=st1[root][c1] st3[C][c1]-=x st5[C]-=x for c1c2 in st2[root]: x=st2[root][c1c2] st4[C][c1c2]-=x st6[C][c1c2//K]-=x st6[C][c1c2%K]-=x q=[(root,c)] while len(q)>0: s,sc=q[-1] if v[s]==0: v[s]=1 uc[sc]+=1 if uc[sc]==1: u.add(sc) if len(u)==1: c1=sorted(u)[0] ans+=st5[C]-st3[C][c1] ans+=st6[C][c1]*2 if c1 in st6[C] else 0 if len(u)==2: c1,c2=sorted(u) c1c2=c1*K+c2 ans+=st4[C][c1c2] for t,tc in e[s]: if v[t]==0 and cdp[t]>cdp[C]: q+=[(t,tc)] else: v[s]=0 uc[sc]-=1 if uc[sc]==0: u.remove(sc) q.pop() for c1 in st1[root]: x=st1[root][c1] st3[C][c1]+=x st5[C]+=x for c1c2 in st2[root]: x=st2[root][c1c2] st4[C][c1c2]+=x st6[C][c1c2//K]+=x st6[C][c1c2%K]+=x print(ans//2)