結果

問題 No.2471 Gemini Tree(Ver.Lapislazuli)
ユーザー Yerin JungYerin Jung
提出日時 2023-09-24 01:37:05
言語 C++14
(gcc 12.3.0 + boost 1.83.0)
結果
WA  
実行時間 -
コード長 3,665 bytes
コンパイル時間 1,671 ms
コンパイル使用メモリ 169,452 KB
実行使用メモリ 25,104 KB
最終ジャッジ日時 2023-09-24 01:37:11
合計ジャッジ時間 4,802 ms
ジャッジサーバーID
(参考情報)
judge13 / judge14
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 5 ms
13,508 KB
testcase_01 AC 4 ms
13,516 KB
testcase_02 AC 5 ms
13,508 KB
testcase_03 AC 4 ms
13,524 KB
testcase_04 WA -
testcase_05 WA -
testcase_06 WA -
testcase_07 WA -
testcase_08 WA -
testcase_09 WA -
testcase_10 WA -
testcase_11 WA -
testcase_12 WA -
testcase_13 WA -
testcase_14 WA -
testcase_15 WA -
testcase_16 WA -
testcase_17 WA -
testcase_18 WA -
testcase_19 WA -
testcase_20 WA -
testcase_21 WA -
testcase_22 WA -
testcase_23 WA -
testcase_24 WA -
testcase_25 WA -
testcase_26 WA -
testcase_27 WA -
testcase_28 WA -
testcase_29 WA -
testcase_30 WA -
testcase_31 WA -
testcase_32 WA -
testcase_33 WA -
testcase_34 WA -
testcase_35 WA -
testcase_36 WA -
testcase_37 WA -
権限があれば一括ダウンロードができます

ソースコード

diff #

#include<bits/stdc++.h>
using namespace std ;
typedef long long ll ;
typedef unsigned long long ull ;
typedef pair < int , int > pii ;
typedef vector<int> vi;
#define fi first
#define se second
mt19937 rng(chrono::high_resolution_clock::now().time_since_epoch().count());

#define rep(i, a, b) for(int i = a; i < (b); ++i)
#define all(x) begin(x), end(x)
#define sz(x) (int)(x).size()

const int MOD = 998244353 ;
const int MAXN = 1e5 + 7 ;

int n ;
vector < int > v[ MAXN ] ;
int subtree[ MAXN ] , freq[ MAXN ] , leaves[ MAXN ] ;

int lst[ MAXN ] ;
ll fac[ MAXN ] , inv[ MAXN ] ;

ll fastpow ( ll x , ll pw ) {
    ll ret = 1 ;
    while ( pw > 0 ) {
        if ( ( pw % 2 ) == 0 ) {
            x = ( x * x ) % MOD ;
            pw /= 2 ;
        }
        else {
            ret = ( ret * x ) % MOD ;
            -- pw ;
        }
    }
    return ret ;
}

ll comb ( int up , int down ) {
    if ( up < down || down < 0 ) { return 0 ; }
    ll ret = fac[ up ] ;
    ret = ( ret * inv[ down ] ) % MOD ;
    ret = ( ret * inv[ up - down ] ) % MOD ;
    return ret ;
}

void init ( int x , int prv ) {
    subtree[ x ] = 1 ;
    for ( auto y : v[ x ] ) {
        if ( y == prv ) { continue ; }
        lst[ y ] = x ;
        init ( y , x ) ;
        subtree[ x ] += subtree[ y ] ;
        leaves[ x ] += leaves[ y ] ;
    }
    if ( subtree[ x ] == 1 ) {
        ++ leaves[ x ] ;
    }
}

vector < int > aux[ MAXN ] ;
vector < int > verts[ MAXN ] ;

void solve ( ) {
    cin >> n ;
    for ( int i = 1 , x , y ; i < n ; ++ i ) {
        cin >> x >> y ;
        v[ x ].push_back ( y ) ;
        v[ y ].push_back ( x ) ;
    }
    fac[ 0 ] = 1 ;
    for ( int i = 1 ; i <= n ; ++ i ) {
        fac[ i ] = ( fac[ i - 1 ] * i ) % MOD ;
    }
    inv[ n ] = fastpow ( fac[ n ] , MOD - 2 ) ;
    for ( int i = n - 1 ; i >= 0 ; -- i ) {
        inv[ i ] = ( inv[ i + 1 ] * ( i + 1 ) ) % MOD ;
    }
    int root = 0 ;
    for ( int i = 1 ; i <= n ; ++ i ) {
        if ( (int)v[ i ].size ( ) >= 2 ) {
            root = i ;
            break ;
        }
    }
    if ( root == 0 ) {
        cout << "4\n" ;
        return ;
    }
    init ( root , -1 ) ;
    for ( int i = 1 ; i <= n ; ++ i ) {
        ++ freq[ subtree[ i ] ] ;
        aux[ subtree[ i ] ].push_back ( leaves[ i ] ) ;
        verts[ subtree[ i ] ].push_back ( i ) ;
    }
    ll ans = 2 ;
    for ( int i = 1 ; i < n ; ++ i ) {
        if ( freq[ i ] == 0 && freq[ n - i ] == 0 ) { continue ; }
        if ( freq[ i ] + freq[ n - i ] > 1 ) {
            ans = ( ans + comb ( n , i ) ) % MOD ;
        }
        else {
            ll add = comb ( n , i ) ;
            ll bad = 0 ; 
            if ( freq[ i ] > 0 ) {
                if ( aux[ i ][ 0 ] <= i && leaves[ 1 ] - aux[ i ][ 0 ] <= n - i ) {
                    bad = ( bad + comb ( n - leaves[ 1 ] , i - aux[ i ][ 0 ] ) ) % MOD ;
                }
            }
            add = ( add + MOD - bad ) % MOD ;
            ans = ( ans + add ) % MOD ;
            bool done = false ;
            for ( auto x : verts[ i - 1 ] ) {
                if ( subtree[ lst[ x ] ] == i ) { continue ; }
                else { done = true ; break ; }
            }
            for ( auto x : verts[ n - i - 1 ] ) {
                if ( subtree[ lst[ x ] ] == n - i ) { continue ; }
                else { done = true ; break ; }                
            }
            if ( done == true ) {
                ans = ( ans + bad ) % MOD ;
            }
        }
    }
    cout << ans << "\n" ;
}


int main ( ) {
    ios_base :: sync_with_stdio ( false ) ;
    cin.tie ( NULL ) ;
    int t = 1 ; // cin >> t ;
    while ( t -- ) { solve ( ) ; }
    return 0 ;
}
0