#include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #define popcount __builtin_popcount using namespace std; typedef long long int ll; typedef pair P; const int MAX_V = 200020; // ツリーのサイズのありうる最大値 int N; // ツリーのサイズ vector

tree[MAX_V]; // ツリーを隣接リスト形式のグラフ構造で表したもの int sizeSubtree[MAX_V]; // sizeSubtree[v] := v を根とする部分ツリーのサイズ (分割統治の毎ステップごとに再利用) bool isRemoved[MAX_V]; // isRemoved[v] := v が既に取り除かれたかどうか int whoIsParent[MAX_V]; // whoIsParent[v] := ツリーDP時に v の親が誰だったか // メイン処理 vector centroids; void FindCentroidRecursive(int v, int size, int p = -1) { sizeSubtree[v] = 1; whoIsParent[v] = p; bool isCentroid = true; for (auto pr : tree[v]) { int ch=pr.first; if (ch == p) continue; if (isRemoved[ch]) continue; FindCentroidRecursive(ch, size, v); if (sizeSubtree[ch] > size / 2) isCentroid = false; sizeSubtree[v] += sizeSubtree[ch]; } if (size - sizeSubtree[v] > size / 2) isCentroid = false; if (isCentroid) centroids.push_back(v); } // 初期化 void Init() { for (int i = 0; i < MAX_V; ++i) { isRemoved[i] = false; } } // first: 重心, second: (重心の子ノード, 子部分木のサイズ) からなるベクトル pair > > FindCentroid(int root, int size) { vector > subtrees; centroids.clear(); FindCentroidRecursive(root, size); int center = centroids[0]; for (auto pr : tree[center]) { int ch=pr.first; if (isRemoved[ch]) continue; if (ch == whoIsParent[center]) { subtrees.push_back(make_pair(ch, size - sizeSubtree[center])); } else { subtrees.push_back(make_pair(ch, sizeSubtree[ch])); } } return make_pair(center, subtrees); } map mp1; map mp2; void dfs(int v, int p, vector vc){ //whoIsParent[v] = p; if(vc.size()==1) mp2[vc[0]]++; else if(vc.size()==2) mp1[{vc[0], vc[1]}]++; for(auto pr:tree[v]){ int ch=pr.first; if (ch == p) continue; if (isRemoved[ch]) continue; vector vc2=vc; vc2.push_back(pr.second); sort(vc2.begin(), vc2.end()); vc2.erase(unique(vc2.begin(), vc2.end()), vc2.end()); if(vc2.size()<=2) dfs(ch, v, vc2); } } ll ans; void solve(int root, int size){ if(size<=1) return; pair > > pr=FindCentroid(root, size); int cent=pr.first; mp1.clear(); mp2.clear(); vector vc0; dfs(cent, -1, vc0); for(auto p:mp1) ans+=p.second; auto calc=[&](){ ll ans1=0; for(auto p:mp1){ ans1+=p.second*p.second; int x=p.first.first, y=p.first.second; if(mp2.find(x)!=mp2.end()){ ans1+=p.second*mp2[x]*2; } if(mp2.find(y)!=mp2.end()){ ans1+=p.second*mp2[y]*2; } } ll sum=0; for(auto p:mp2) sum+=p.second; for(auto p:mp2) ans1+=p.second*(sum-p.second); return ans1; }; ll ans1=calc(); isRemoved[cent]=1; for(auto pr:tree[cent]){ int y=pr.first; if(!isRemoved[y]){ mp1.clear(); mp2.clear(); vector vc0(1, pr.second); dfs(y, -1, vc0); ans1-=calc(); } } ans+=ans1/2; for(auto prr:pr.second){ if(!isRemoved[prr.first]) solve(prr.first, prr.second); } } int main() { int n, k; cin>>n>>k; for(int i=0; i