結果
問題 | No.1002 Twotone |
ユーザー |
![]() |
提出日時 | 2025-02-28 00:03:25 |
言語 | PyPy3 (7.3.15) |
結果 |
TLE
|
実行時間 | - |
コード長 | 4,397 bytes |
コンパイル時間 | 336 ms |
コンパイル使用メモリ | 82,492 KB |
実行使用メモリ | 318,364 KB |
最終ジャッジ日時 | 2025-02-28 00:03:35 |
合計ジャッジ時間 | 9,773 ms |
ジャッジサーバーID (参考情報) |
judge1 / judge4 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | AC * 2 |
other | AC * 1 TLE * 1 -- * 31 |
ソースコード
n,K=map(int,input().split()) e=[[] for i in range(n)] for _ in range(n-1): a,b,c=map(int,input().split()) a-=1 b-=1 c-=1 e[a]+=[(b,c)] e[b]+=[(a,c)] def CD(e): c=[] p=[-1]*n d=[0]*n v=[0]*n u=[0]*n sq=[0] for start in sq: o=[] q=[start] while len(q)>0: s=q[-1] if v[s]==0: v[s]=1 o+=[s] q+=[t for t,_ in e[s] if v[t]==0 and p[t]==-1] else: u[s]=1+sum(u[t] for t,_ in e[s] if v[t]==0 and p[t]==-1) v[s]=0 q.pop() q=[start] while len(q)>0: s=q[-1] if v[s]==0: v[s]=1 for t,_ in e[s]: if v[t]==0 and p[t]==-1: q+=[t] if u[start]-u[t]<u[t]: d[s]+=1 else: d[t]+=1 else: v[s]=0 q.pop() for i in o: if d[i]==0: nc=i u[i]=0 d[i]=0 p[nc]=len(c) c+=[nc] sq+=[t for t,_ in e[nc] if p[t]==-1] return c,p cdc,cdp=CD(e) ans=0 cq=[0]*K cqs=0 sto=[[] for i in range(n)] cc=[0]*K cs=set() v=[0]*n p=[[] for i in range(n)] g=[0]*n for nowc in cdc[::-1]: for start,rc in e[nowc]: if cdp[start]>cdp[nowc]: cc[rc]+=1 cs.add(rc) q=[(start,rc)] while len(q)>0: s,sc=q[-1] if v[s]==0: v[s]=1 sto[start]+=[s] if len(cs)<=2: p[s]=list(cs).copy() while g[s]<len(e[s]): t,tc=e[s][g[s]] if v[t]==0 and cdp[t]>cdp[nowc]: break g[s]+=1 if g[s]<len(e[s]): cc[tc]+=1 if cc[tc]==1: cs.add(tc) q+=[(t,tc)] else: cc[sc]-=1 if cc[sc]==0: cs.discard(sc) q.pop() for start,_ in e[nowc]: if cdp[start]>cdp[nowc]: for i in sto[start]: if len(p[i])==1: x=p[i][0] cq[x]+=1 cqs+=1 for start,_ in e[nowc]: if cdp[start]>cdp[nowc]: for i in sto[start]: if len(p[i])==1: x=p[i][0] cq[x]-=1 cqs-=1 for i in sto[start]: if len(p[i])==1: x=p[i][0] ans+=cqs-cq[x] for start,_ in e[nowc]: if cdp[start]>cdp[nowc]: for i in sto[start]: if len(p[i])==1: x=p[i][0] cq[x]+=1 cqs+=1 for start,_ in e[nowc]: if cdp[start]>cdp[nowc]: for i in sto[start]: if len(p[i])==1: x=p[i][0] cq[x]-=1 cqs-=1 for i in sto[start]: if len(p[i])==2: x1,x2=p[i] ans+=cq[x1]+cq[x2] for i in sto[start]: if len(p[i])==1: x=p[i][0] cq[x]+=1 cqs+=1 for start,_ in e[nowc]: if cdp[start]>cdp[nowc]: for i in sto[start]: if len(p[i])==1: x=p[i][0] cq[x]-=1 cqs-=1 for start,_ in e[nowc]: if cdp[start]>cdp[nowc]: for i in sto[start]: v[i]=0 p[i]=[] g[i]=0 sto[start].clear() cq={} sto=[[] for i in range(n)] cc=[0]*K cs=set() v=[0]*n p=[[] for i in range(n)] g=[0]*n for nowc in cdc[::-1]: for start,rc in e[nowc]: if cdp[start]>cdp[nowc]: cc[rc]+=1 cs.add(rc) q=[(start,rc)] while len(q)>0: s,sc=q[-1] if v[s]==0: v[s]=1 sto[start]+=[s] if len(cs)==2: p[s]=sorted(cs).copy() while g[s]<len(e[s]): t,tc=e[s][g[s]] if v[t]==0 and cdp[t]>cdp[nowc]: break g[s]+=1 if g[s]<len(e[s]): cc[tc]+=1 if cc[tc]==1: cs.add(tc) q+=[(t,tc)] else: cc[sc]-=1 if cc[sc]==0: cs.discard(sc) q.pop() for start,_ in e[nowc]: if cdp[start]>cdp[nowc]: for i in sto[start]: if len(p[i])==2: x=tuple(p[i]) if x not in cq: cq[x]=0 cq[x]+=1 for start,_ in e[nowc]: if cdp[start]>cdp[nowc]: for i in sto[start]: if len(p[i])==2: x=tuple(p[i]) ans+=1 cq[x]-=1 for i in sto[start]: if len(p[i])==2: x=tuple(p[i]) ans+=cq[x] for start,_ in e[nowc]: if cdp[start]>cdp[nowc]: for i in sto[start]: v[i]=0 p[i]=[] g[i]=0 sto[start].clear() print(ans)