結果

問題 No.2471 Gemini Tree(Ver.Lapislazuli)
ユーザー Yerin JungYerin Jung
提出日時 2023-09-24 00:51:49
言語 C++14
(gcc 12.3.0 + boost 1.83.0)
結果
WA  
実行時間 -
コード長 3,300 bytes
コンパイル時間 1,850 ms
コンパイル使用メモリ 171,444 KB
実行使用メモリ 21,192 KB
最終ジャッジ日時 2024-07-17 00:47:06
合計ジャッジ時間 4,749 ms
ジャッジサーバーID
(参考情報)
judge5 / judge1
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 5 ms
10,720 KB
testcase_01 AC 6 ms
10,824 KB
testcase_02 AC 6 ms
10,956 KB
testcase_03 AC 5 ms
10,852 KB
testcase_04 WA -
testcase_05 AC 5 ms
10,984 KB
testcase_06 WA -
testcase_07 WA -
testcase_08 WA -
testcase_09 WA -
testcase_10 WA -
testcase_11 WA -
testcase_12 WA -
testcase_13 WA -
testcase_14 WA -
testcase_15 WA -
testcase_16 WA -
testcase_17 WA -
testcase_18 WA -
testcase_19 WA -
testcase_20 WA -
testcase_21 WA -
testcase_22 WA -
testcase_23 WA -
testcase_24 WA -
testcase_25 WA -
testcase_26 WA -
testcase_27 WA -
testcase_28 WA -
testcase_29 WA -
testcase_30 WA -
testcase_31 WA -
testcase_32 AC 40 ms
19,784 KB
testcase_33 WA -
testcase_34 WA -
testcase_35 WA -
testcase_36 WA -
testcase_37 WA -
権限があれば一括ダウンロードができます

ソースコード

diff #

#include<bits/stdc++.h>
using namespace std ;
typedef long long ll ;
typedef unsigned long long ull ;
typedef pair < int , int > pii ;
typedef vector<int> vi;
#define fi first
#define se second
mt19937 rng(chrono::high_resolution_clock::now().time_since_epoch().count());

#define rep(i, a, b) for(int i = a; i < (b); ++i)
#define all(x) begin(x), end(x)
#define sz(x) (int)(x).size()

const int MOD = 998244353 ;
const int MAXN = 1e5 + 7 ;

int n ;
vector < int > v[ MAXN ] ;
int subtree[ MAXN ] , freq[ MAXN ] , leaves[ MAXN ] ;

ll fac[ MAXN ] , inv[ MAXN ] ;

ll fastpow ( ll x , ll pw ) {
    ll ret = 1 ;
    while ( pw > 0 ) {
        if ( ( pw % 2 ) == 0 ) {
            x = ( x * x ) % MOD ;
            pw /= 2 ;
        }
        else {
            ret = ( ret * x ) % MOD ;
            -- pw ;
        }
    }
    return ret ;
}

ll comb ( int up , int down ) {
    if ( up < down || down < 0 ) { return 0 ; }
    ll ret = fac[ up ] ;
    ret = ( ret * inv[ down ] ) % MOD ;
    ret = ( ret * inv[ up - down ] ) % MOD ;
    return ret ;
}

void init ( int x , int prv ) {
    subtree[ x ] = 1 ;
    for ( auto y : v[ x ] ) {
        if ( y == prv ) { continue ; }
        init ( y , x ) ;
        subtree[ x ] += subtree[ y ] ;
        leaves[ x ] += leaves[ y ] ;
    }
    if ( subtree[ x ] == 1 ) {
        ++ leaves[ x ] ;
    }
}

vector < int > aux[ MAXN ] ;

void solve ( ) {
    cin >> n ;
    for ( int i = 1 , x , y ; i < n ; ++ i ) {
        cin >> x >> y ;
        v[ x ].push_back ( y ) ;
        v[ y ].push_back ( x ) ;
    }
    fac[ 0 ] = 1 ;
    for ( int i = 1 ; i <= n ; ++ i ) {
        fac[ i ] = ( fac[ i - 1 ] * i ) % MOD ;
    }
    inv[ n ] = fastpow ( fac[ n ] , MOD - 2 ) ;
    for ( int i = n - 1 ; i >= 0 ; -- i ) {
        inv[ i ] = ( inv[ i + 1 ] * ( i + 1 ) ) % MOD ;
    }
    int root = 0 ;
    for ( int i = 1 ; i <= n ; ++ i ) {
        if ( (int)v[ i ].size ( ) >= 2 ) {
            root = i ;
            break ;
        }
    }
    if ( root == 0 ) {
        cout << "4\n" ;
        return ;
    }
    init ( root , -1 ) ;
    for ( int i = 1 ; i <= n ; ++ i ) {
        ++ freq[ subtree[ i ] ] ;
        aux[ subtree[ i ] ].push_back ( leaves[ i ] ) ;
    }
    ll ans = 1 ;
    for ( int i = 1 ; i <= n - i ; ++ i ) {
        if ( freq[ i ] == 0 && freq[ n - i ] == 0 ) { continue ; }
        int tot = freq[ i - 1 ] ;
        if ( ( i - 1 ) != n - i ) { tot += freq[ n - i ] ; }
        if ( tot > 1 ) {
            ans = ( ans + comb ( n , i ) ) % MOD ;
        }
        else {
            ll add = comb ( n , i ) ;
            for ( auto vals : aux[ i - 1 ] ) {
                if ( vals < i ) { continue ; }
                add = ( add + MOD - comb ( n - leaves[ 1 ] , vals - i ) ) % MOD ;
            }
            if ( i - 1 != n - i ) { 
                for ( auto vals : aux[ n - i ] ) {
                    if ( vals < i ) { continue ; }
                    add = ( add + MOD - comb ( n - leaves[ 1 ] , vals - i ) ) % MOD ;
                }
            }
            ans = ( ans + add ) % MOD ;
        }
    }
    ans = ( ans * 2 ) % MOD ;
    cout << ans << "\n" ;
}


int main ( ) {
    ios_base :: sync_with_stdio ( false ) ;
    cin.tie ( NULL ) ;
    int t = 1 ; // cin >> t ;
    while ( t -- ) { solve ( ) ; }
    return 0 ;
}
0