結果

問題 No.2471 Gemini Tree(Ver.Lapislazuli)
ユーザー Yerin Jung
提出日時 2023-09-24 01:00:55
言語 C++14
(gcc 13.3.0 + boost 1.87.0)
結果
WA  
実行時間 -
コード長 3,228 bytes
コンパイル時間 1,949 ms
コンパイル使用メモリ 171,540 KB
実行使用メモリ 21,196 KB
最終ジャッジ日時 2024-07-17 00:55:22
合計ジャッジ時間 4,519 ms
ジャッジサーバーID
(参考情報)
judge1 / judge2
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 3
other AC * 4 WA * 31
権限があれば一括ダウンロードができます

ソースコード

diff #

#include<bits/stdc++.h>
using namespace std ;
typedef long long ll ;
typedef unsigned long long ull ;
typedef pair < int , int > pii ;
typedef vector<int> vi;
#define fi first
#define se second
mt19937 rng(chrono::high_resolution_clock::now().time_since_epoch().count());

#define rep(i, a, b) for(int i = a; i < (b); ++i)
#define all(x) begin(x), end(x)
#define sz(x) (int)(x).size()

const int MOD = 998244353 ;
const int MAXN = 1e5 + 7 ;

int n ;
vector < int > v[ MAXN ] ;
int subtree[ MAXN ] , freq[ MAXN ] , leaves[ MAXN ] ;

ll fac[ MAXN ] , inv[ MAXN ] ;

ll fastpow ( ll x , ll pw ) {
    ll ret = 1 ;
    while ( pw > 0 ) {
        if ( ( pw % 2 ) == 0 ) {
            x = ( x * x ) % MOD ;
            pw /= 2 ;
        }
        else {
            ret = ( ret * x ) % MOD ;
            -- pw ;
        }
    }
    return ret ;
}

ll comb ( int up , int down ) {
    if ( up < down || down < 0 ) { return 0 ; }
    ll ret = fac[ up ] ;
    ret = ( ret * inv[ down ] ) % MOD ;
    ret = ( ret * inv[ up - down ] ) % MOD ;
    return ret ;
}

void init ( int x , int prv ) {
    subtree[ x ] = 1 ;
    for ( auto y : v[ x ] ) {
        if ( y == prv ) { continue ; }
        init ( y , x ) ;
        subtree[ x ] += subtree[ y ] ;
        leaves[ x ] += leaves[ y ] ;
    }
    if ( subtree[ x ] == 1 ) {
        ++ leaves[ x ] ;
    }
}

vector < int > aux[ MAXN ] ;

void solve ( ) {
    cin >> n ;
    for ( int i = 1 , x , y ; i < n ; ++ i ) {
        cin >> x >> y ;
        v[ x ].push_back ( y ) ;
        v[ y ].push_back ( x ) ;
    }
    fac[ 0 ] = 1 ;
    for ( int i = 1 ; i <= n ; ++ i ) {
        fac[ i ] = ( fac[ i - 1 ] * i ) % MOD ;
    }
    inv[ n ] = fastpow ( fac[ n ] , MOD - 2 ) ;
    for ( int i = n - 1 ; i >= 0 ; -- i ) {
        inv[ i ] = ( inv[ i + 1 ] * ( i + 1 ) ) % MOD ;
    }
    int root = 0 ;
    for ( int i = 1 ; i <= n ; ++ i ) {
        if ( (int)v[ i ].size ( ) >= 2 ) {
            root = i ;
            break ;
        }
    }
    if ( root == 0 ) {
        cout << "4\n" ;
        return ;
    }
    init ( root , -1 ) ;
    for ( int i = 1 ; i <= n ; ++ i ) {
        ++ freq[ subtree[ i ] ] ;
        aux[ subtree[ i ] ].push_back ( leaves[ i ] ) ;
    }
    ll ans = 2 ;
    for ( int i = 1 ; i < n ; ++ i ) {
        if ( freq[ i ] == 0 && freq[ n - i ] == 0 ) { continue ; }
        if ( freq[ i - 1 ] >= 1 ) {
            ans = ( ans + comb ( n , i ) ) % MOD ;
        }
        else if ( freq[ n - i - 1 ] >= 1 ) {
            ans = ( ans + comb ( n , i ) ) % MOD ;
        }
        else {
            ll add = comb ( n , i ) ;
            for ( auto vals : aux[ i ] ) {
                if ( vals < n - i ) { continue ; }
                add = ( add + MOD - comb ( n - leaves[ 1 ] , vals - ( n - i ) ) ) % MOD ;
            }
            for ( auto vals : aux[ n - i ] ) {
                if ( vals < i ) { continue ; }
                add = ( add + MOD - comb ( n - leaves[ 1 ] , vals - i ) ) % MOD ;
            }
            ans = ( ans + add ) % MOD ;
        }
    }
    cout << ans << "\n" ;
}


int main ( ) {
    ios_base :: sync_with_stdio ( false ) ;
    cin.tie ( NULL ) ;
    int t = 1 ; // cin >> t ;
    while ( t -- ) { solve ( ) ; }
    return 0 ;
}
0