#include #define rep(i,n) for(int i=0;i<(n);i++) using namespace std; using lint=long long; class union_find{ int n; vector p; public: union_find(int n):n(n),p(n,-1){} int find(int u){ return p[u]<0?u:p[u]=find(p[u]); } void unite(int u,int v){ u=find(u); v=find(v); if(u!=v){ if(p[v]> L(n); rep(u,n) L[U1.find(u)].emplace_back(U2.find(u)); vector cnt(n); rep(u,n){ sort(L[u].begin(),L[u].end()); L[u].erase(unique(L[u].begin(),L[u].end()),L[u].end()); for(int v:L[u]) cnt[u]+=U2.size(v); } lint ans=0; rep(u,n) ans+=cnt[U1.find(u)]-1; printf("%lld\n",ans); return 0; }