/*
 * NOTE(review): this file appears to have been damaged by a tool that removed
 * every "<...>" span (HTML-tag stripping) and collapsed the source onto two
 * lines. Visible symptoms:
 *   - the #include directives have lost their targets, so the entire first
 *     line is parsed as a single malformed #include directive;
 *   - template arguments are missing (e.g. `std::vector> &mat`,
 *     `std::array tot`, `std::vector(n)`);
 *   - whole spans are gone wherever a `<` comparison was followed later by a
 *     `>`: the rest of `minus` after `if(x`, the definitions of the `add` /
 *     `sub` helpers (called below but not visible) and the head of
 *     `struct coefficient`, almost all of `eliminate` plus the start of
 *     `main` (after `if(n`), and the loop that fills `mat` (after
 *     `for(int i=0;i`).
 * It cannot compile as-is. Restore it from the original source instead of
 * reconstructing the missing parts by guesswork.
 *
 * What the visible code does (all arithmetic is modulo mod = 998244353):
 *   power(x, y)   x^y mod p by square-and-multiply.
 *   plus / minus  modular addition / subtraction of two reduced residues.
 *   coefficient   a vector of residues (element type presumably uint; the
 *                 template argument was stripped). It supports element-wise
 *                 += and -= (sizes asserted equal), scalar *= and *, and /=.
 *                 /= multiplies by power(x, mod-2), the Fermat inverse, so it
 *                 assumes x != 0 mod p.
 *   eliminate     takes an augmented matrix (m = column count - 1) and a
 *                 result vector `res`, and returns bool; main asserts it
 *                 returns true. Presumably Gaussian elimination mod p; the
 *                 body is not visible.
 *   main          reads n-1 triples "u v l" (n itself is read in the stripped
 *                 part). For each endpoint vertex it increments
 *                 deg[vertex][l-1], and it increments tot[l-1]. tot has two
 *                 elements, so l must be 1 or 2.
 *
 * The DP. f(x, y) is the quantity defined by the equation in the trailing
 * comment, with 0 <= x <= tot[0], 0 <= y <= tot[1] and f(tot[0], tot[1]) = 0.
 * dp[x][y] stores f(x, y) as an affine combination of the tot[1]+1 unknowns
 * f(0, k):
 *     dp[x][y][k]          coefficient of f(0, k), 0 <= k <= tot[1]
 *     dp[x][y][tot[1]+1]   constant term
 * How each row is filled:
 *   - dp[0][k] is seeded as the unit vector e_k.
 *   - Equation (x, y) is multiplied through by all*free and solved for its
 *     (tot[0]-x)*z*f(x+1, y) term, which gives dp[x+1][y]. Here z = rem =
 *     n-1-x-y, and delta = all*free/n is the constant.
 *   - `tot[1]-j<<1` parses as (tot[1]-j)<<1, i.e. 2*(tot[1]-j), as the
 *     equation requires.
 *   - When x == tot[0] the division is skipped. dp[tot[0]+1][y] then holds
 *     the left-hand side of equation (tot[0], y), which must equal 0.
 * `mat` presumably collects those conditions plus f(tot[0], tot[1]) = 0
 * (its construction was stripped). eliminate() then yields res[k] = f(0, k).
 *
 * dp2[x][y] evaluates f(x, y) from res. The printed answer is
 *     sum over vertices v of f(deg[v][0], deg[v][1])
 *     - tot[0] * f(1, 0)   (only if tot[0] != 0)
 *     - tot[1] * f(0, 1)   (only if tot[1] != 0)
 * all taken mod p.
 */
#include #include #include #include #include #include typedef unsigned int uint; typedef unsigned long long ull; constexpr uint mod{998244353}; constexpr uint power(uint x,uint y) { uint s{1}; while(y>0) { if(y&1) { s=ull(s)*x%mod; } x=ull(x)*x%mod; y>>=1; } return s; } constexpr uint plus(const uint &x,const uint &y) { if(x+y>=mod) { return x+y-mod; } return x+y; } constexpr uint minus(const uint &x,const uint &y) { if(x { coefficient(const std::size_t &n):std::vector(n) {} coefficient &operator+=(const coefficient &x) { assert(size()==x.size()); for(int i=0;i!=size();i++) { add((*this)[i],x[i]); } return *this; } coefficient &operator-=(const coefficient &x) { assert(size()==x.size()); for(int i=0;i!=size();i++) { sub((*this)[i],x[i]); } return *this; } coefficient &operator*=(const uint &x) { for(int i=0;i!=size();i++) { (*this)[i]=ull(x)*(*this)[i]%mod; } return *this; } coefficient operator*(const uint &x) const { return coefficient(*this)*=x; } coefficient &operator/=(const uint &x) { return operator*=(power(x,mod-2)); } }; bool eliminate(std::vector> &mat,std::vector &res) { int n{int(mat.size())},m{int(mat[0].size())-1}; if(n> deg(n); std::array tot({0,0}); for(int i=1;i<=n-1;i++) { int u,v,l; scanf("%d%d%d",&u,&v,&l); ++deg[u-1][l-1]; ++deg[v-1][l-1]; ++tot[l-1]; } int all{tot[0]+tot[1]*2}; int free{(n*(n-1)>>1)-(n-2)}; uint delta{uint(ull(all)*free%mod*power(n,mod-2)%mod)}; std::vector> dp(tot[0]+2,std::vector(tot[1]+1,coefficient(tot[1]+2))); for(int i=0;i<=tot[1];i++) { dp[0][i][i]=1; } for(int i=0;i<=tot[0];i++) { for(int j=0;j<=tot[1];j++) { int rem{n-1-i-j}; dp[i+1][j]+=dp[i][j]*((ull(all)*free+ull(mod-i-(j<<1))*(rem+1)+ull(mod-(tot[0]-i)-(tot[1]-j<<1))*(free-rem))%mod); if(i!=0) { dp[i+1][j]+=dp[i-1][j]*(ull(mod-i)*(free-rem-1)%mod); } if(j!=0) { dp[i+1][j]+=dp[i][j-1]*(ull(mod-(j<<1))*(free-rem-1)%mod); } if(j!=tot[1]) { dp[i+1][j]+=dp[i][j+1]*(ull(mod-(tot[1]-j<<1))*rem%mod); } sub(dp[i+1][j][tot[1]+1],delta); if(i!=tot[0]) { 
dp[i+1][j]/=ull(tot[0]-i)*rem%mod; } } } std::vector> mat; for(int i=0;i res(tot[1]+1); /* NOTE(review): eliminate() is only called inside assert(). A build with -DNDEBUG drops the call entirely and leaves res at its initial (presumably zero) values, which gives a silently wrong answer. Call it unconditionally, store the bool, and assert on that. */ assert(eliminate(mat,res)); std::vector> dp2(tot[0]+1,std::vector(tot[1]+1)); for(int i=0;i<=tot[0];i++) { for(int j=0;j<=tot[1];j++) { for(int k=0;k<=tot[1];k++) { dp2[i][j]=(dp2[i][j]+ull(dp[i][j][k])*res[k])%mod; } add(dp2[i][j],dp[i][j][tot[1]+1]); } } uint ans{0}; for(int i=0;i!=n;i++) { add(ans,dp2[deg[i][0]][deg[i][1]]); } if(tot[0]!=0) { ans=(ans+ull(mod-tot[0])*dp2[1][0])%mod; } if(tot[1]!=0) { ans=(ans+ull(mod-tot[1])*dp2[0][1])%mod; } printf("%u\n",ans); return 0; } /* for all 0<=x<=tot[0] and 0<=y<=tot[1], if x = tot[0] and y = tot[1], let f(x,y) = 0, otherwise, let z = n-1-x-y, and f(x,y) = x/all * ((z+1)/free*f(x,y) + (free-z-1)/free*f(x-1,y)) + (tot[0]-x)/all * (z/free*f(x+1,y) + (free-z)/free*f(x,y)) + 2*y/all * ((z+1)/free*f(x,y) + (free-z-1)/free*f(x,y-1)) + 2*(tot[1]-y)/all * (z/free*f(x,y+1) + (free-z)/free*f(x,y)) + 1/n. Multiplying both sides by all*free and moving every term to the left-hand side (note all-x-2*y = (tot[0]-x)+2*(tot[1]-y)) gives: (all*free-(x+2*y)*(z+1)-(all-x-2*y)*(free-z))*f(x,y) - x * (free-z-1) * f(x-1,y) - (tot[0]-x) * z * f(x+1,y) - 2 * y * (free-z-1) * f(x,y-1) - 2 * (tot[1]-y) * z * f(x,y+1) - all*free/n = 0 */