結果
問題 |
No.1784 Not a star yet...
|
ユーザー |
![]() |
提出日時 | 2025-06-05 20:37:22 |
言語 | C++23 (gcc 13.3.0 + boost 1.87.0) |
結果 |
AC
|
実行時間 | 48 ms / 2,000 ms |
コード長 | 5,399 bytes |
コンパイル時間 | 1,853 ms |
コンパイル使用メモリ | 119,300 KB |
実行使用メモリ | 12,288 KB |
最終ジャッジ日時 | 2025-06-05 20:37:28 |
合計ジャッジ時間 | 5,115 ms |
ジャッジサーバーID (参考情報) |
judge2 / judge3 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | AC * 3 |
other | AC * 61 |
コンパイルメッセージ
main.cpp: In function ‘int main()’: main.cpp:111:10: warning: ignoring return value of ‘int scanf(const char*, ...)’ declared with attribute ‘warn_unused_result’ [-Wunused-result] 111 | scanf("%d",&n); | ~~~~~^~~~~~~~~ main.cpp:116:14: warning: ignoring return value of ‘int scanf(const char*, ...)’ declared with attribute ‘warn_unused_result’ [-Wunused-result] 116 | scanf("%d%d%d",&u,&v,&l); | ~~~~~^~~~~~~~~~~~~~~~~~~
ソースコード
#include<stdio.h> #include<iostream> #include<algorithm> #include<cassert> #include<vector> #include<array> typedef unsigned int uint; typedef unsigned long long ull; constexpr uint mod{998244353}; constexpr uint power(uint x,uint y) { uint s{1}; while(y>0) { if(y&1) { s=ull(s)*x%mod; } x=ull(x)*x%mod; y>>=1; } return s; } constexpr uint plus(const uint &x,const uint &y) { if(x+y>=mod) { return x+y-mod; } return x+y; } constexpr uint minus(const uint &x,const uint &y) { if(x<y) { return x-y+mod; } return x-y; } constexpr void add(uint &x,const uint &y) { x=plus(x,y); } constexpr void sub(uint &x,const uint &y) { x=minus(x,y); } struct coefficient:std::vector<uint> { coefficient(const std::size_t &n):std::vector<uint>(n) {} coefficient &operator+=(const coefficient &x) { assert(size()==x.size()); for(int i=0;i!=size();i++) { add((*this)[i],x[i]); } return *this; } coefficient &operator-=(const coefficient &x) { assert(size()==x.size()); for(int i=0;i!=size();i++) { sub((*this)[i],x[i]); } return *this; } coefficient &operator*=(const uint &x) { for(int i=0;i!=size();i++) { (*this)[i]=ull(x)*(*this)[i]%mod; } return *this; } coefficient operator*(const uint &x) const { return coefficient(*this)*=x; } coefficient &operator/=(const uint &x) { return operator*=(power(x,mod-2)); } }; bool eliminate(std::vector<std::vector<uint>> &mat,std::vector<uint> &res) { int n{int(mat.size())},m{int(mat[0].size())-1}; if(n<m) { return false; } for(int i=0;i!=m;i++) { int _i{-1}; for(int j=i;j!=n;j++) { if(mat[j][i]) { _i=j; break; } } if(_i==-1) { return false; } if(i!=_i) { mat[i].swap(mat[_i]); } uint inv{power(mat[i][i],mod-2)};; for(int j=i+1;j!=n;j++) { uint tmp{uint(ull(mod-mat[j][i])*inv%mod)}; for(int k=i;k!=m+1;k++) { mat[j][k]=(mat[j][k]+ull(tmp)*mat[i][k])%mod; } } } for(int i=m;i!=n;i++) { if(mat[i][m]!=0) { return false; } } res.resize(m); for(int i=m-1;i!=-1;i--) { res[i]=ull(mod-mat[i][m])*power(mat[i][i],mod-2)%mod; for(int j=i-1;j!=-1;j--) { mat[j][m]=(mat[j][m]+ull(res[i])*mat[j][i])%mod; } } return true; } int main() { int n; scanf("%d",&n); std::vector<std::array<int,2>> deg(n); std::array<int,2> tot({0,0}); for(int i=1;i<=n-1;i++) { int u,v,l; scanf("%d%d%d",&u,&v,&l); ++deg[u-1][l-1]; ++deg[v-1][l-1]; ++tot[l-1]; } int all{tot[0]+tot[1]*2}; int free{(n*(n-1)>>1)-(n-2)}; uint delta{uint(ull(all)*free%mod*power(n,mod-2)%mod)}; std::vector<std::vector<coefficient>> dp(tot[0]+2,std::vector<coefficient>(tot[1]+1,coefficient(tot[1]+2))); for(int i=0;i<=tot[1];i++) { dp[0][i][i]=1; } for(int i=0;i<=tot[0];i++) { for(int j=0;j<=tot[1];j++) { int rem{n-1-i-j}; dp[i+1][j]+=dp[i][j]*((ull(all)*free+ull(mod-i-(j<<1))*(rem+1)+ull(mod-(tot[0]-i)-(tot[1]-j<<1))*(free-rem))%mod); if(i!=0) { dp[i+1][j]+=dp[i-1][j]*(ull(mod-i)*(free-rem-1)%mod); } if(j!=0) { dp[i+1][j]+=dp[i][j-1]*(ull(mod-(j<<1))*(free-rem-1)%mod); } if(j!=tot[1]) { dp[i+1][j]+=dp[i][j+1]*(ull(mod-(tot[1]-j<<1))*rem%mod); } sub(dp[i+1][j][tot[1]+1],delta); if(i!=tot[0]) { dp[i+1][j]/=ull(tot[0]-i)*rem%mod; } } } std::vector<std::vector<uint>> mat; for(int i=0;i<tot[1];i++) { mat.emplace_back(dp[tot[0]+1][i]); } mat.emplace_back(dp[tot[0]][tot[1]]); std::vector<uint> res(tot[1]+1); assert(eliminate(mat,res)); std::vector<std::vector<uint>> dp2(tot[0]+1,std::vector<uint>(tot[1]+1)); for(int i=0;i<=tot[0];i++) { for(int j=0;j<=tot[1];j++) { for(int k=0;k<=tot[1];k++) { dp2[i][j]=(dp2[i][j]+ull(dp[i][j][k])*res[k])%mod; } add(dp2[i][j],dp[i][j][tot[1]+1]); } } uint ans{0}; for(int i=0;i!=n;i++) { add(ans,dp2[deg[i][0]][deg[i][1]]); } if(tot[0]!=0) { ans=(ans+ull(mod-tot[0])*dp2[1][0])%mod; } if(tot[1]!=0) { ans=(ans+ull(mod-tot[1])*dp2[0][1])%mod; } printf("%u\n",ans); return 0; } /* for all 0<=x<=tot[0] and 0<=y<=tot[1], if x = tot[0] and y = tot[1], let f(x,y) = 0, otherwise, let z = n-1-x-y, and f(x,y) = x/all * ((z+1)/free*f(x,y) + (free-z-1)/free*f(x-1,y)) + (tot[0]-x)/all * (z/free*f(x+1,y) + (free-z)/free*f(x,y)) + 2*y/all * ((z+1)/free*f(x,y) + (free-z-1)/free*f(x,y-1)) + 2*(tot[1]-y)/all * (z/free*f(x,y+1) + (free-z)/free*f(x,y)) + 1/n (all*free-(x+2*y)*(z+1)-(all-x-2*y)*(free-z))*f(x,y) - x * (free-z-1) * f(x-1,y) - (tot[0]-x) * z * f(x+1,y) - 2 * y * (free-z-1) * f(x,y-1) - 2 * (tot[1]-y) * z * f(x,y+1) - all*free/n = 0 */