結果

問題 No.1784 Not a star yet...
ユーザー rqoi031
提出日時 2025-06-05 20:37:22
言語 C++23
(gcc 13.3.0 + boost 1.87.0)
結果
AC  
実行時間 48 ms / 2,000 ms
コード長 5,399 bytes
コンパイル時間 1,853 ms
コンパイル使用メモリ 119,300 KB
実行使用メモリ 12,288 KB
最終ジャッジ日時 2025-06-05 20:37:28
合計ジャッジ時間 5,115 ms
ジャッジサーバーID
(参考情報)
judge2 / judge3
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 3
other AC * 61
権限があれば一括ダウンロードができます
コンパイルメッセージ
main.cpp: In function ‘int main()’:
main.cpp:111:10: warning: ignoring return value of ‘int scanf(const char*, ...)’ declared with attribute ‘warn_unused_result’ [-Wunused-result]
  111 |     scanf("%d",&n);
      |     ~~~~~^~~~~~~~~
main.cpp:116:14: warning: ignoring return value of ‘int scanf(const char*, ...)’ declared with attribute ‘warn_unused_result’ [-Wunused-result]
  116 |         scanf("%d%d%d",&u,&v,&l);
      |         ~~~~~^~~~~~~~~~~~~~~~~~~

ソースコード

diff #

#include<stdio.h>
#include<iostream>
#include<algorithm>
#include<cassert>
#include<vector>
#include<array>
typedef unsigned int uint;
typedef unsigned long long ull;
constexpr uint mod{998244353};
constexpr uint power(uint x,uint y) {
    uint s{1};
    while(y>0) {
        if(y&1) {
            s=ull(s)*x%mod;
        }
        x=ull(x)*x%mod;
        y>>=1;
    }
    return s;
}
constexpr uint plus(const uint &x,const uint &y) {
    if(x+y>=mod) {
        return x+y-mod;
    }
    return x+y;
}
constexpr uint minus(const uint &x,const uint &y) {
    if(x<y) {
        return x-y+mod;
    }
    return x-y;
}
constexpr void add(uint &x,const uint &y) {
    x=plus(x,y);
}
constexpr void sub(uint &x,const uint &y) {
    x=minus(x,y);
}
struct coefficient:std::vector<uint> {
    coefficient(const std::size_t &n):std::vector<uint>(n) {}
    coefficient &operator+=(const coefficient &x) {
        assert(size()==x.size());
        for(int i=0;i!=size();i++) {
            add((*this)[i],x[i]);
        }
        return *this;
    }
    coefficient &operator-=(const coefficient &x) {
        assert(size()==x.size());
        for(int i=0;i!=size();i++) {
            sub((*this)[i],x[i]);
        }
        return *this;
    }
    coefficient &operator*=(const uint &x) {
        for(int i=0;i!=size();i++) {
            (*this)[i]=ull(x)*(*this)[i]%mod;
        }
        return *this;
    }
    coefficient operator*(const uint &x) const {
        return coefficient(*this)*=x;
    }
    coefficient &operator/=(const uint &x) {
        return operator*=(power(x,mod-2));
    }
};
bool eliminate(std::vector<std::vector<uint>> &mat,std::vector<uint> &res) {
    int n{int(mat.size())},m{int(mat[0].size())-1};
    if(n<m) {
        return false;
    }
    for(int i=0;i!=m;i++) {
        int _i{-1};
        for(int j=i;j!=n;j++) {
            if(mat[j][i]) {
                _i=j;
                break;
            }
        }
        if(_i==-1) {
            return false;
        }
        if(i!=_i) {
            mat[i].swap(mat[_i]);
        }
        uint inv{power(mat[i][i],mod-2)};;
        for(int j=i+1;j!=n;j++) {
            uint tmp{uint(ull(mod-mat[j][i])*inv%mod)};
            for(int k=i;k!=m+1;k++) {
                mat[j][k]=(mat[j][k]+ull(tmp)*mat[i][k])%mod;
            }
        }
    }
    for(int i=m;i!=n;i++) {
        if(mat[i][m]!=0) {
            return false;
        }
    }
    res.resize(m);
    for(int i=m-1;i!=-1;i--) {
        res[i]=ull(mod-mat[i][m])*power(mat[i][i],mod-2)%mod;
        for(int j=i-1;j!=-1;j--) {
            mat[j][m]=(mat[j][m]+ull(res[i])*mat[j][i])%mod;
        }
    }
    return true;
}
int main() {
    int n;
    scanf("%d",&n);
    std::vector<std::array<int,2>> deg(n);
    std::array<int,2> tot({0,0});
    for(int i=1;i<=n-1;i++) {
        int u,v,l;
        scanf("%d%d%d",&u,&v,&l);
        ++deg[u-1][l-1];
        ++deg[v-1][l-1];
        ++tot[l-1];
    }
    int all{tot[0]+tot[1]*2};
    int free{(n*(n-1)>>1)-(n-2)};
    uint delta{uint(ull(all)*free%mod*power(n,mod-2)%mod)};
    std::vector<std::vector<coefficient>> dp(tot[0]+2,std::vector<coefficient>(tot[1]+1,coefficient(tot[1]+2)));
    for(int i=0;i<=tot[1];i++) {
        dp[0][i][i]=1;
    }
    for(int i=0;i<=tot[0];i++) {
        for(int j=0;j<=tot[1];j++) {
            int rem{n-1-i-j};
            dp[i+1][j]+=dp[i][j]*((ull(all)*free+ull(mod-i-(j<<1))*(rem+1)+ull(mod-(tot[0]-i)-(tot[1]-j<<1))*(free-rem))%mod);
            if(i!=0) {
                dp[i+1][j]+=dp[i-1][j]*(ull(mod-i)*(free-rem-1)%mod);
            }
            if(j!=0) {
                dp[i+1][j]+=dp[i][j-1]*(ull(mod-(j<<1))*(free-rem-1)%mod);
            }
            if(j!=tot[1]) {
                dp[i+1][j]+=dp[i][j+1]*(ull(mod-(tot[1]-j<<1))*rem%mod);
            }
            sub(dp[i+1][j][tot[1]+1],delta);
            if(i!=tot[0]) {
                dp[i+1][j]/=ull(tot[0]-i)*rem%mod;
            }
        }
    }
    std::vector<std::vector<uint>> mat;
    for(int i=0;i<tot[1];i++) {
        mat.emplace_back(dp[tot[0]+1][i]);
    }
    mat.emplace_back(dp[tot[0]][tot[1]]);
    std::vector<uint> res(tot[1]+1);
    assert(eliminate(mat,res));
    std::vector<std::vector<uint>> dp2(tot[0]+1,std::vector<uint>(tot[1]+1));
    for(int i=0;i<=tot[0];i++) {
        for(int j=0;j<=tot[1];j++) {
            for(int k=0;k<=tot[1];k++) {
                dp2[i][j]=(dp2[i][j]+ull(dp[i][j][k])*res[k])%mod;
            }
            add(dp2[i][j],dp[i][j][tot[1]+1]);
        }
    }
    uint ans{0};
    for(int i=0;i!=n;i++) {
        add(ans,dp2[deg[i][0]][deg[i][1]]);
    }
    if(tot[0]!=0) {
        ans=(ans+ull(mod-tot[0])*dp2[1][0])%mod;
    }
    if(tot[1]!=0) {
        ans=(ans+ull(mod-tot[1])*dp2[0][1])%mod;
    }
    printf("%u\n",ans);
    return 0;
}
/*
for all 0<=x<=tot[0] and 0<=y<=tot[1],
if x = tot[0] and y = tot[1], let f(x,y) = 0,
otherwise, let z = n-1-x-y, and
f(x,y) = x/all * ((z+1)/free*f(x,y) + (free-z-1)/free*f(x-1,y))
       + (tot[0]-x)/all * (z/free*f(x+1,y) + (free-z)/free*f(x,y))
       + 2*y/all * ((z+1)/free*f(x,y) + (free-z-1)/free*f(x,y-1))
       + 2*(tot[1]-y)/all * (z/free*f(x,y+1) + (free-z)/free*f(x,y))
       + 1/n

(all*free-(x+2*y)*(z+1)-(all-x-2*y)*(free-z))*f(x,y)
- x * (free-z-1) * f(x-1,y)
- (tot[0]-x) * z * f(x+1,y)
- 2 * y * (free-z-1) * f(x,y-1)
- 2 * (tot[1]-y) * z * f(x,y+1)
- all*free/n = 0
*/
0