結果

問題 No.3346 Tree to DAG
コンテスト
ユーザー 2251799813685248
提出日時 2025-11-13 23:47:46
言語 C++17
(gcc 13.3.0 + boost 1.87.0)
結果
WA  
実行時間 -
コード長 3,883 bytes
コンパイル時間 1,372 ms
コンパイル使用メモリ 124,472 KB
実行使用メモリ 15,872 KB
最終ジャッジ日時 2025-11-13 23:47:51
合計ジャッジ時間 4,787 ms
ジャッジサーバーID
(参考情報)
judge4 / judge2
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 2
other AC * 28 WA * 11
権限があれば一括ダウンロードができます

ソースコード

diff #

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <functional>
#include <iomanip>
#include <iostream>
#include <map>
#include <queue>
#include <set>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>


using namespace std;
// Shorthand types and constants used throughout the solution.
#define ll long long
#define MOD 998244353
#define ld long double
#define INF 2251799813685248
// Container / I/O helper macros. Every macro argument is parenthesised so that
// arguments such as *p or a+b, and uses such as encode(a,b)*2, expand as intended.
#define vall(A) (A).begin(),(A).end()
// Read H strings of length W into vv[0..H-1][0..W-1] (one char per cell).
#define gridinput(vv,H,W) for (ll i = 0; i < (H); i++){string T; cin >> T; for(ll j = 0; j < (W); j++){(vv)[i][j] = {T[j]};}}
// Same as gridinput but fills the 1-indexed cells vv[1..H][1..W].
#define adjustedgridinput(vv,H,W) for (ll i = 1; i <= (H); i++){string T; cin >> T; for(ll j = 1; j <= (W); j++){(vv)[i][j] = {T[j-1]};}}
#define vin(A) for (ll i = 0, sz = (A).size(); i < sz; i++){cin >> (A)[i];}
#define vout(A) for(ll i = 0, sz = (A).size(); i < sz; i++){cout << (A)[i] << " \n"[i == sz-1];}
// 1-indexed variants: element 0 is skipped.
#define adjustedvin(A) for (ll i = 1, sz = (A).size(); i < sz; i++){cin >> (A)[i];}
#define adjustedvout(A) for(ll i = 1, sz = (A).size(); i < sz; i++){cout << (A)[i] << " \n"[i == sz-1];}
#define vout2d(A,H,W) for (ll i = 0; i < (H); i++){for (ll j = 0; j < (W); j++){cout << (A)[i][j] << " \n"[j == (W)-1];}}
// Pack a pair into one integer as i * 2^32 + j; decode(v, true) yields the low
// 32 bits (j), decode(v, false) the high part (i).
// NOTE(review): assumes i is a (non-negative) long long and 0 <= j < 2^32.
#define encode(i,j) (((i)<<32)+(j))
#define decode(v,w) ((w) ? (v)%4294967296 : (v)>>32)
// pow2ll[e] = 2^e for e = 0..32; pow10ll[e] = 10^e for e = 0..18.
vector<ll> pow2ll{1,2,4,8,16,32,64,128,256,512,1024,2048,4096,8192,16384,32768,65536,131072,262144,524288,1048576,2097152,4194304,8388608,16777216,33554432,67108864,134217728,268435456,536870912,1073741824,2147483648,4294967296};
vector<ll> pow10ll{1,10,100,1000,10000,100000,1000000,10000000,100000000,1000000000,10000000000,100000000000,1000000000000,10000000000000,100000000000000,1000000000000000,10000000000000000,100000000000000000,1000000000000000000};
// Grid step offsets: (di[d], dj[d]) for d = 0..3 is (0,1), (1,0), (0,-1), (-1,0).
vector<ll> di{0,1,0,-1};
vector<ll> dj{1,0,-1,0};

/// @brief Returns a^b mod m by binary exponentiation (O(log b) multiplications).
/// @param a base; may be negative or >= m (it is reduced into [0, m) first)
/// @param b exponent; any b <= 0 is treated as a^0
/// @param m modulus, must be >= 1; (m-1)^2 must fit in a long long
/// @return a^b mod m, always in [0, m)
long long modpow(long long a, long long b, long long m){
    // Reduce the base up front: otherwise base*base overflows once |a| exceeds
    // ~3.04e9, and a negative base would produce a negative "residue".
    long long base = a % m;
    if (base < 0){
        base += m;
    }
    long long result = 1 % m;   // 1 % m so that m == 1 correctly yields 0
    while (b > 0){
        if (b & 1){
            result = result * base % m;
        }
        base = base * base % m;
        b >>= 1;
    }
    return result;
}


/// @brief Depth-first traversal from `now` that assigns D[v] = D[parent] + 1 to
///        every newly reached vertex, so on a tree D[v] becomes dist(now, v).
///
/// Entries equal to -1 are treated as unvisited; the caller must set D[now]
/// beforehand. Neighbours are visited in exactly the same order as a recursive
/// DFS would, but an explicit stack is used so that a deep tree (e.g. a path of
/// ~N vertices) cannot overflow the call stack.
/// @param E adjacency lists (not modified)
/// @param D per-vertex depth array, updated in place
/// @param now start vertex
void dfs(std::vector<std::vector<long long>> &E, std::vector<long long> &D, long long now){
    std::vector<long long> path{now};        // vertices on the current DFS path
    std::vector<std::size_t> next_edge{0};   // per path vertex: next neighbour index
    while (!path.empty()){
        const long long u = path.back();
        if (next_edge.back() == E[u].size()){
            // All neighbours of u examined: return to its parent.
            path.pop_back();
            next_edge.pop_back();
            continue;
        }
        const long long v = E[u][next_edge.back()++];
        if (D[v] != -1){
            continue;
        }
        D[v] = D[u] + 1;
        path.push_back(v);
        next_edge.push_back(0);
    }
}


int main(){
    ios::sync_with_stdio(false);
    std::cin.tie(nullptr);
    ll N;
    cin >> N;
    vector<vector<ll>> E(N+1);
    for (ll i = 0; i < N-1; i++){
        ll u,v;
        cin >> u >> v;
        E[u].push_back(v);
        E[v].push_back(u);
    }
    vector<ll> D(N+1,-1);
    vector<ll> D2(N+1,-1);
    vector<ll> D3(N+1,-1);
    D[1] = 0;
    dfs(E,D,1);
    ll dist_max_v = -1;
    ll d_max = -1;
    for (ll i = 1; i <= N; i++){
        if (D[i] >= d_max){
            d_max = D[i];
            dist_max_v = i;
        }
    }

    D2[dist_max_v] = 0;
    dfs(E,D2,dist_max_v);
    ll dist_max_v2 = -1;
    d_max = -1;
    for (ll i = 1; i <= N; i++){
        if (D2[i] >= d_max){
            d_max = D2[i];
            dist_max_v2 = i;
        }
    }

    D3[dist_max_v2] = 0;
    dfs(E,D3,dist_max_v2);



    vector<vector<ll>> ans;
    ll b = -1;
    ll t = 0;
    ll diameter = D2[dist_max_v2];
    for (ll i = 1; i <= N; i++){
        if (i == dist_max_v || i == dist_max_v2){continue;}
        if (t < D2[i]+D3[i]){
            t = D2[i] + D3[i];
            b = i;
        }
    }

    D = vector<ll>(N+1,-1);
    D[b] = 0;
    dfs(E,D,b);
    ll c = 0;
    ll u = __LONG_LONG_MAX__;
    for (ll i = 1; i <= N; i++){
        if (D2[i] + D3[i] == diameter){
            if (u >= D[i]){
                u = D[i];
                c = i;
            }
        }
    }

    ll k = D3[c];
    ll m = D2[c];
    ll l = (t-diameter)/2;

    ll answer = modpow(2,N+2,MOD) - modpow(2,N-k-l-m,MOD)*((2*(modpow(2,k,MOD) + modpow(2,l,MOD) + modpow(2,m,MOD))-3)%MOD);



    cout << (MOD+(answer%MOD))%MOD << "\n";
}
0