結果
| 問題 | No.1784 Not a star yet... |
| コンテスト | |
| ユーザー |
rqoi031
|
| 提出日時 | 2025-06-05 20:37:22 |
| 言語 | C++23 (gcc 13.3.0 + boost 1.89.0) |
| 結果 |
AC
|
| 実行時間 | 48 ms / 2,000 ms |
| コード長 | 5,399 bytes |
| 記録 | |
| コンパイル時間 | 1,853 ms |
| コンパイル使用メモリ | 119,300 KB |
| 実行使用メモリ | 12,288 KB |
| 最終ジャッジ日時 | 2025-06-05 20:37:28 |
| 合計ジャッジ時間 | 5,115 ms |
|
ジャッジサーバーID (参考情報) |
judge2 / judge3 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 3 |
| other | AC * 61 |
コンパイルメッセージ
main.cpp: In function ‘int main()’:
main.cpp:111:10: warning: ignoring return value of ‘int scanf(const char*, ...)’ declared with attribute ‘warn_unused_result’ [-Wunused-result]
111 | scanf("%d",&n);
| ~~~~~^~~~~~~~~
main.cpp:116:14: warning: ignoring return value of ‘int scanf(const char*, ...)’ declared with attribute ‘warn_unused_result’ [-Wunused-result]
116 | scanf("%d%d%d",&u,&v,&l);
| ~~~~~^~~~~~~~~~~~~~~~~~~
ソースコード
#include<stdio.h>
#include<iostream>
#include<algorithm>
#include<cassert>
#include<vector>
#include<array>
typedef unsigned int uint;
typedef unsigned long long ull;
constexpr uint mod{998244353};
constexpr uint power(uint x,uint y) {
uint s{1};
while(y>0) {
if(y&1) {
s=ull(s)*x%mod;
}
x=ull(x)*x%mod;
y>>=1;
}
return s;
}
constexpr uint plus(const uint &x,const uint &y) {
if(x+y>=mod) {
return x+y-mod;
}
return x+y;
}
constexpr uint minus(const uint &x,const uint &y) {
if(x<y) {
return x-y+mod;
}
return x-y;
}
constexpr void add(uint &x,const uint &y) {
x=plus(x,y);
}
constexpr void sub(uint &x,const uint &y) {
x=minus(x,y);
}
struct coefficient:std::vector<uint> {
coefficient(const std::size_t &n):std::vector<uint>(n) {}
coefficient &operator+=(const coefficient &x) {
assert(size()==x.size());
for(int i=0;i!=size();i++) {
add((*this)[i],x[i]);
}
return *this;
}
coefficient &operator-=(const coefficient &x) {
assert(size()==x.size());
for(int i=0;i!=size();i++) {
sub((*this)[i],x[i]);
}
return *this;
}
coefficient &operator*=(const uint &x) {
for(int i=0;i!=size();i++) {
(*this)[i]=ull(x)*(*this)[i]%mod;
}
return *this;
}
coefficient operator*(const uint &x) const {
return coefficient(*this)*=x;
}
coefficient &operator/=(const uint &x) {
return operator*=(power(x,mod-2));
}
};
bool eliminate(std::vector<std::vector<uint>> &mat,std::vector<uint> &res) {
int n{int(mat.size())},m{int(mat[0].size())-1};
if(n<m) {
return false;
}
for(int i=0;i!=m;i++) {
int _i{-1};
for(int j=i;j!=n;j++) {
if(mat[j][i]) {
_i=j;
break;
}
}
if(_i==-1) {
return false;
}
if(i!=_i) {
mat[i].swap(mat[_i]);
}
uint inv{power(mat[i][i],mod-2)};;
for(int j=i+1;j!=n;j++) {
uint tmp{uint(ull(mod-mat[j][i])*inv%mod)};
for(int k=i;k!=m+1;k++) {
mat[j][k]=(mat[j][k]+ull(tmp)*mat[i][k])%mod;
}
}
}
for(int i=m;i!=n;i++) {
if(mat[i][m]!=0) {
return false;
}
}
res.resize(m);
for(int i=m-1;i!=-1;i--) {
res[i]=ull(mod-mat[i][m])*power(mat[i][i],mod-2)%mod;
for(int j=i-1;j!=-1;j--) {
mat[j][m]=(mat[j][m]+ull(res[i])*mat[j][i])%mod;
}
}
return true;
}
int main() {
int n;
scanf("%d",&n);
std::vector<std::array<int,2>> deg(n);
std::array<int,2> tot({0,0});
for(int i=1;i<=n-1;i++) {
int u,v,l;
scanf("%d%d%d",&u,&v,&l);
++deg[u-1][l-1];
++deg[v-1][l-1];
++tot[l-1];
}
int all{tot[0]+tot[1]*2};
int free{(n*(n-1)>>1)-(n-2)};
uint delta{uint(ull(all)*free%mod*power(n,mod-2)%mod)};
std::vector<std::vector<coefficient>> dp(tot[0]+2,std::vector<coefficient>(tot[1]+1,coefficient(tot[1]+2)));
for(int i=0;i<=tot[1];i++) {
dp[0][i][i]=1;
}
for(int i=0;i<=tot[0];i++) {
for(int j=0;j<=tot[1];j++) {
int rem{n-1-i-j};
dp[i+1][j]+=dp[i][j]*((ull(all)*free+ull(mod-i-(j<<1))*(rem+1)+ull(mod-(tot[0]-i)-(tot[1]-j<<1))*(free-rem))%mod);
if(i!=0) {
dp[i+1][j]+=dp[i-1][j]*(ull(mod-i)*(free-rem-1)%mod);
}
if(j!=0) {
dp[i+1][j]+=dp[i][j-1]*(ull(mod-(j<<1))*(free-rem-1)%mod);
}
if(j!=tot[1]) {
dp[i+1][j]+=dp[i][j+1]*(ull(mod-(tot[1]-j<<1))*rem%mod);
}
sub(dp[i+1][j][tot[1]+1],delta);
if(i!=tot[0]) {
dp[i+1][j]/=ull(tot[0]-i)*rem%mod;
}
}
}
std::vector<std::vector<uint>> mat;
for(int i=0;i<tot[1];i++) {
mat.emplace_back(dp[tot[0]+1][i]);
}
mat.emplace_back(dp[tot[0]][tot[1]]);
std::vector<uint> res(tot[1]+1);
assert(eliminate(mat,res));
std::vector<std::vector<uint>> dp2(tot[0]+1,std::vector<uint>(tot[1]+1));
for(int i=0;i<=tot[0];i++) {
for(int j=0;j<=tot[1];j++) {
for(int k=0;k<=tot[1];k++) {
dp2[i][j]=(dp2[i][j]+ull(dp[i][j][k])*res[k])%mod;
}
add(dp2[i][j],dp[i][j][tot[1]+1]);
}
}
uint ans{0};
for(int i=0;i!=n;i++) {
add(ans,dp2[deg[i][0]][deg[i][1]]);
}
if(tot[0]!=0) {
ans=(ans+ull(mod-tot[0])*dp2[1][0])%mod;
}
if(tot[1]!=0) {
ans=(ans+ull(mod-tot[1])*dp2[0][1])%mod;
}
printf("%u\n",ans);
return 0;
}
/*
for all 0<=x<=tot[0] and 0<=y<=tot[1],
if x = tot[0] and y = tot[1], let f(x,y) = 0,
otherwise, let z = n-1-x-y, and
f(x,y) = x/all * ((z+1)/free*f(x,y) + (free-z-1)/free*f(x-1,y))
+ (tot[0]-x)/all * (z/free*f(x+1,y) + (free-z)/free*f(x,y))
+ 2*y/all * ((z+1)/free*f(x,y) + (free-z-1)/free*f(x,y-1))
+ 2*(tot[1]-y)/all * (z/free*f(x,y+1) + (free-z)/free*f(x,y))
+ 1/n
(all*free-(x+2*y)*(z+1)-(all-x-2*y)*(free-z))*f(x,y)
- x * (free-z-1) * f(x-1,y)
- (tot[0]-x) * z * f(x+1,y)
- 2 * y * (free-z-1) * f(x,y-1)
- 2 * (tot[1]-y) * z * f(x,y+1)
- all*free/n = 0
*/
rqoi031