結果

問題 No.3346 Tree to DAG
コンテスト
ユーザー rhoo
提出日時 2025-11-13 23:18:35
言語 Rust
(1.83.0 + proconio)
結果
AC  
実行時間 181 ms / 2,000 ms
コード長 12,567 bytes
コンパイル時間 15,341 ms
コンパイル使用メモリ 399,844 KB
実行使用メモリ 62,052 KB
最終ジャッジ日時 2025-11-13 23:18:56
合計ジャッジ時間 18,622 ms
ジャッジサーバーID
(参考情報)
judge4 / judge2
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 2
other AC * 39
権限があれば一括ダウンロードができます

ソースコード

diff #

#![allow(unused_imports,non_snake_case,dead_code)]
use std::{cmp::Reverse as Rev,ops::Range,collections::*,iter::*,mem::swap};
use proconio::{marker::*,*};



/// Entry point: reads a tree and prints the answer modulo 998244353 (`M` below).
///
/// Input: `n`, then `n-1` edges given 1-based and converted to 0-based by
/// proconio's `Usize1`.
///
/// What the code computes (the derivation of the final formula is not in this file):
/// 1. `u` = a vertex farthest from vertex 0, `v` = a vertex farthest from `u`
///    (double sweep, so the u-v path is a diameter of the tree).
/// 2. For every other vertex `m`: `a` = the three pairwise distances within
///    {u, v, m} sorted ascending, and `tot` = number of edges of the smallest
///    subtree spanning u, v, m (summed over the edges of their virtual tree).
///    The lexicographically largest `a` is kept; ties keep the smallest `tot`.
/// 3. answer = 2^(n+2) - sum_i 2^(n+1-a_i) + 3 * 2^(n-tot).
///
/// NOTE(review): when n <= 2 the loop never runs, `cur.1` stays `INF`, and
/// `g(cur.1)` underflows `u64` -- presumably the constraints guarantee n >= 3; confirm.
/// NOTE(review): `dist` recurses once per tree level, so a very deep tree relies on
/// the judge providing a large enough stack.
#[fastout]
fn main(){
    input!{
        n:usize,
        es:[(Usize1,Usize1);n-1],
    }

    // Adjacency lists of the undirected tree.
    let mut g=vec![vec![];n];
    for &(u,v) in &es{
        g[u].push(v);
        g[v].push(u);
    }

    // `dist!(!0, r, d)` writes into `d[x]` the depth of every vertex x when the
    // tree is rooted at `r`. `g` is captured by reference through `cap!`.
    cap!{
        #![&g:Vec<Vec<usize>>]
        fn dist(p:usize,v:usize,d:&mut [usize]){
            if p==!0{
                d[v]=0;
            } else{
                d[v]=d[p]+1;
            }

            for &nv in &g[v]{
                if nv!=p{
                    dist!(v,nv,d);
                }
            }
        }
    }

    // Double sweep: 0 -> farthest vertex u -> farthest vertex v.
    let mut a=vec![0;n];
    dist!(!0,0,&mut a);
    
    let u=(0..n).max_by_key(|&i|a[i]).unwrap();
    let mut ud=vec![0;n];
    dist!(!0,u,&mut ud);

    let v=(0..n).max_by_key(|&i|ud[i]).unwrap();
    // NOTE(review): `vd` is computed but never read below.
    let mut vd=vec![0;n];
    dist!(!0,v,&mut vd);

    let lca=LCA::new(&g,0);
    // Sentinel for "no candidate yet"; it only survives if the loop body never runs.
    let INF=1e8 as usize;
    // Best (sorted pairwise distances, subtree edge count) seen so far.
    let mut cur=([0;3],INF);
    
    for m in 0..n{
        if m==u || m==v{
            continue;
        }

        let mut a=[lca.get_dist(u,v),lca.get_dist(u,m),lca.get_dist(m,v)];
        a.sort();
        let es=lca.get_virtual_tree(&[u,v,m]);

        // Summing the virtual-tree edge lengths counts the edges of the minimal
        // subtree that connects u, v and m.
        let mut tot=0;
        for &(u,v) in &es{
            tot+=lca.get_dist(u,v);
        }

        // A lexicographically larger `a` wins; on equal `a`, a smaller-or-equal `tot` wins.
        if cur.0<a || cur.0<=a && tot<=cur.1{
            cur=(a,tot);
        }
    }

    let n=n as u64;
    // f(a) = 2^(n+1-a)
    let f=|a:usize|{
        let a=a as u64;
        M::new(2).pow((n+2)-(a+2)+1)
    };
    // g(a) = 2^(n-a)  (shadows the adjacency list `g`, which is no longer used)
    let g=|a:usize|{
        let a=a as u64;
        M::new(2).pow((n+2)-(a+3)+1)
    };

    let mut ans=M::new(0);
    ans+=M::new(2).pow(n+2);
    ans-=f(cur.0[0])+f(cur.0[1])+f(cur.0[2]);
    ans+=g(cur.1)*3;

    println!("{ans}");
}


/// Modular integer type used for the answer (modulus 998244353).
type M=ModInt998244353;



/// Defines local functions that "capture" variables of the enclosing scope by
/// receiving them as extra leading parameters. For each function it also generates
/// a same-named macro that supplies those captures automatically.
///
/// Example (from `main` in this file):
///
/// ```text
/// cap!{
///     #![&g:Vec<Vec<usize>>]
///     fn dist(p:usize,v:usize,d:&mut [usize]){ /* ... */ dist!(v,nv,d); }
/// }
/// dist!(!0,0,&mut a);   // expands to dist(&g, !0, 0, &mut a)
/// ```
///
/// Capture list syntax inside `#![...]` (comma separated):
/// - `name:T`       -- parameter `name:T`,      caller-scope macro passes `name`
/// - `&name:T`      -- parameter `name:&T`,     caller-scope macro passes `&name`
/// - `&mut name:T`  -- parameter `name:&mut T`, caller-scope macro passes `&mut name`
///
/// Each function may declare `-> R`; without it the return type is `()`.
/// - Inside every generated function, `f!(args)` calls any function of the same
///   `cap!` group and forwards the captured parameters by name.
/// - After the group, `f!(args)` is also defined in the enclosing scope.
#[macro_export]
macro_rules! cap{
    ()=>{};
    // Entry points. Each normalises the capture list (adding a trailing comma) and
    // starts the accumulator form:
    // cap!([captures left],[param decls],[param names],[caller-side args],
    //      [fn tokens left],[parsed fns],[fn names]).
    (#![] $($fn:tt)*)=>{
        cap!([],[],[],[],[$($fn)*],[],[]);
    };
    (#![$($g:tt)*] $($fn:tt)*)=>{
        cap!([$($g)*,],[],[],[],[$($fn)*],[],[]);
    };
    // Capture `name:T`, passed by value.
    ([$name:ident:$t:ty,$($rem:tt)*],[$($ga:tt)*],[$($gn:tt)*],[$($ge:tt)*],[$($fn:tt)*],[],[])=>{
        cap!([$($rem)*],[$name:$t,$($ga)*],[$name,$($gn)*],[$name,$($ge)*],[$($fn)*],[],[]);
    };
    // Capture with a single prefix token (e.g. `&name:T`). The token is applied both
    // to the parameter type and to the caller-side argument.
    ([$($flag:tt)? $name:ident:$t:ty,$($rem:tt)*],[$($ga:tt)*],[$($gn:tt)*],[$($ge:tt)*],[$($fn:tt)*],[],[])=>{
        cap!([$($rem)*],[$name:$($flag)?$t,$($ga)*],[$name,$($gn)*],[$($flag)?$name,$($ge)*],[$($fn)*],[],[]);
    };
    // Capture `&mut name:T`.
    ([&mut $name:ident:$t:ty,$($rem:tt)*],[$($ga:tt)*],[$($gn:tt)*],[$($ge:tt)*],[$($fn:tt)*],[],[])=>{
        cap!([$($rem)*],[$name:&mut $t,$($ga)*],[$name,$($gn)*],[&mut $name,$($ge)*],[$($fn)*],[],[]);
    };
    // Captures exhausted. Parse one `fn` without a return type (recorded as `()`).
    ([$(,)?],[$($ga:tt)*],[$($gn:tt)*],[$($ge:tt)*],[$(#[$($att:tt)*])? fn $name:ident($($arg:tt)*) $body:block $($rem:tt)*],[$($info:tt)*],[$($fn:tt)*])=>{
        cap!([],[$($ga)*],[$($gn)*],[$($ge)*],[$($rem)*],[(($(#[$($att)*])?),$name,($($arg)*),(),$body),$($info)*],[$name,$($fn)*]);
    };
    // Parse one `fn` with an explicit `-> R`.
    ([$(,)?],[$($ga:tt)*],[$($gn:tt)*],[$($ge:tt)*],[$(#[$($att:tt)*])? fn $name:ident($($arg:tt)*)->$ret:ty $body:block $($rem:tt)*],[$($info:tt)*],[$($fn:tt)*])=>{
        cap!([],[$($ga)*],[$($gn)*],[$($ge)*],[$($rem)*],[(($(#[$($att)*])?),$name,($($arg)*),$ret,$body),$($info)*],[$name,$($fn)*]);
    };
    // All functions parsed: emit them.
    ([$(,)?],[$($ga:tt)*],[$($gn:tt)*],[$($ge:tt)*],[],[$($info:tt)*],[$($fn:tt)*])=>{
        cap!(@make_fn [$($ga)*],[$($gn)*],[$($ge)*],[$($info)*],[$($fn)*]);
    };
    // Emit one function with the captures prepended to its parameter list. It also
    // gets local call macros for every function in the group. The `let $ga=$ga;`
    // rebinding is marked allow(unused_variables), presumably so that unused
    // captures do not warn.
    (@make_fn [$($ga:ident:$gt:ty,)*],[$($gn:tt)*],[$($ge:tt)*],[(($($att:tt)*),$name:ident,($($arg:tt)*),$ret:ty,$body:block),$($rem:tt)*],[$($fn:tt)*])=>{
        $($att)*
        fn $name($($ga:$gt,)*$($arg)*)->$ret{
            $(#[allow(unused_variables)] let $ga=$ga;)*
            cap!(@make_macros ($),[$($gn)*],[$($fn)*]);
            $body
        }
        cap!(@make_fn [$($ga:$gt,)*],[$($gn)*],[$($ge)*],[$($rem)*],[$($fn)*]);
    };
    // No functions left: define the caller-scope macros.
    (@make_fn [$($ga:tt)*],[$($gn:tt)*],[$($ge:tt)*],[],[$($fn:tt)*])=>{
        cap!(@make_global_macros ($),[$($ge)*],[$($fn)*]);
    };
    // `$dol` carries a literal `$` token so the generated macro_rules can declare
    // its own `$arg` matcher. Inside a function, captures are forwarded by name.
    (@make_macros ($dol:tt),[$($gn:ident,)*],[$name:ident,$($rem:tt)*])=>{
        #[allow(unused_macros)]
        macro_rules! $name{
            ($dol($dol arg:expr),*)=>{$name($($gn,)* $dol($dol arg),*)}
        }
        cap!(@make_macros ($),[$($gn,)*],[$($rem)*]);
    };
    (@make_macros ($dol:tt),[$($gn:ident,)*],[])=>{};
    // Caller-scope macros pass `name` / `&name` / `&mut name` as recorded above.
    (@make_global_macros ($dol:tt),[$($ge:expr,)*],[$name:ident,$($rem:tt)*])=>{
        #[allow(unused_macros)]
        macro_rules! $name{
            ($dol($dol arg:expr),*)=>{$name($($ge,)* $dol($dol arg),*)}
        }
        cap!(@make_global_macros ($),[$($ge,)*],[$($rem)*]);
    };
    (@make_global_macros ($dol:tt),[$($ge:expr,)*],[])=>{};
}



type ModInt998244353=ModInt<998244353>;
type ModInt1000000007=ModInt<1000000007>;



/// Integer modulo the compile-time modulus `MOD`.
///
/// Intended invariant: the wrapped `u32` is always the canonical residue in `0..MOD`.
#[derive(Clone,Copy,PartialEq,Eq,Default,Hash)]
struct ModInt<const MOD:u32>(u32);
impl<const MOD:u32> ModIntBase for ModInt<MOD>{
    fn modulus()->u32{ MOD }
    fn val(self)->u32{ self.0 }
    /// Reduces `v` modulo `MOD` via `RemU32::rem_u32`.
    fn new(v:impl RemU32)->Self{ Self(v.rem_u32(MOD)) }

    /// Multiplicative inverse, computed with the extended Euclidean algorithm.
    ///
    /// # Panics
    /// If `self` is zero or shares a common factor with `MOD`.
    fn inv(self)->Self{
        assert!(self.0!=0,"ModInt::inv: zero has no multiplicative inverse");

        // Invariant: a == u*self (mod MOD) and b == v*self (mod MOD).
        let (mut a,mut b)=(self.0 as i64,MOD as i64);
        let (mut u,mut v)=(1i64,0i64);
        while b!=0{
            let t=a/b;
            (a,b)=(b,a-t*b);
            (u,v)=(v,u-t*v);
        }
        // `a` is now gcd(self, MOD); an inverse exists only when it is 1.
        assert!(a==1,"ModInt::inv: value is not invertible modulo MOD");

        if u<0{
            u+=MOD as i64;
        }
        Self(u as u32)
    }

    /// `self` raised to the `k`-th power by binary exponentiation (O(log k) multiplications).
    fn pow(self,mut k:u64)->Self{
        let mut pow2=self;
        // `1 % MOD` instead of `1` keeps the result canonical when MOD == 1.
        let mut ret=Self(1%MOD);
        while k>0{
            if k&1==1{
                ret*=pow2;
            }
            pow2*=pow2;
            k>>=1;
        }
        ret
    }
}


impl<const MOD:u32> std::fmt::Display for ModInt<MOD>{
    fn fmt(&self,f:&mut std::fmt::Formatter)->std::fmt::Result{
        write!(f,"{}",self.0)
    }
}
impl<const MOD:u32> std::fmt::Debug for ModInt<MOD>{
    fn fmt(&self,f:&mut std::fmt::Formatter)->std::fmt::Result{
        write!(f,"{}",self.0)
    }
}


/// Addition of canonical residues.
///
/// `overflowing_add` keeps this correct for every modulus below 2^32, including
/// moduli above 2^31 where `self.0 + a.0` would overflow `u32`.
impl<const MOD:u32> std::ops::Add for ModInt<MOD>{
    type Output=Self;
    fn add(self,a:Self)->Self{
        let (sum,overflowed)=self.0.overflowing_add(a.0);
        if overflowed || MOD<=sum{
            // The true sum lies in [MOD, 2*MOD). wrapping_sub recovers sum - MOD
            // even when the u32 addition wrapped around.
            Self(sum.wrapping_sub(MOD))
        } else{
            Self(sum)
        }
    }
}
/// Subtraction of canonical residues.
///
/// This branches instead of computing `self.0 - a.0` directly. That plain `u32`
/// subtraction underflows whenever `self.0 < a.0`, and panics in debug builds.
impl<const MOD:u32> std::ops::Sub for ModInt<MOD>{
    type Output=Self;
    fn sub(self,a:Self)->Self{
        if self.0>=a.0{
            Self(self.0-a.0)
        } else{
            // self.0 < a.0 < MOD, so the result is below MOD and cannot overflow.
            Self(self.0+(MOD-a.0))
        }
    }
}
/// Multiplication through a `u64` intermediate; the product of two residues fits in `u64`.
impl<const MOD:u32> std::ops::Mul for ModInt<MOD>{
    type Output=Self;
    fn mul(self,a:Self)->Self{
        Self((self.0 as u64*a.0 as u64%MOD as u64) as u32)
    }
}
/// Division as multiplication by the inverse; panics if `a` is not invertible.
impl<const MOD:u32> std::ops::Div for ModInt<MOD>{
    type Output=Self;
    fn div(self,a:Self)->Self{
        self*a.inv()
    }
}
/// Additive inverse; zero maps to zero so the result stays canonical.
impl<const MOD:u32> std::ops::Neg for ModInt<MOD>{
    type Output=Self;
    fn neg(self)->Self{
        if self.0==0{
            return self;
        }
        Self(MOD-self.0)
    }
}


/// Parses a decimal `u64` and reduces it modulo `MOD`.
///
/// Any parse failure (empty input, non-digit characters, a value above
/// `u64::MAX`) is returned unchanged from `u64::from_str`.
impl<const MOD:u32> std::str::FromStr for ModInt<MOD>{
    type Err=<u64 as std::str::FromStr>::Err;
    fn from_str(s:&str)->Result<Self,Self::Err>{
        match s.parse::<u64>(){
            Ok(raw)=>Ok(Self::new(raw)),
            Err(err)=>Err(err),
        }
    }
}


/// For one binary operator `Op` (`+`, `-`, `*`, `/`), generates:
/// - `OpAssign<ModInt>`, delegating to the by-value `Op<ModInt>` impl;
/// - `Op<T>` and `OpAssign<T>` for any primitive `T: RemU32`, which is first
///   reduced with `ModInt::new`.
macro_rules! impl_modint_ops{
    ($Op:ident,$op_fn:ident,$OpAssign:ident,$op_assign_fn:ident,$sym:tt)=>{
        impl<T:RemU32,const MOD:u32> std::ops::$Op<T> for ModInt<MOD>{
            type Output=Self;
            fn $op_fn(self,rhs:T)->Self{
                let rhs=Self::new(rhs);
                self $sym rhs
            }
        }
        impl<const MOD:u32> std::ops::$OpAssign for ModInt<MOD>{
            fn $op_assign_fn(&mut self,rhs:Self){
                let lhs=*self;
                *self=lhs $sym rhs;
            }
        }
        impl<T:RemU32,const MOD:u32> std::ops::$OpAssign<T> for ModInt<MOD>{
            fn $op_assign_fn(&mut self,rhs:T){
                let rhs=Self::new(rhs);
                *self=*self $sym rhs;
            }
        }
    };
}
impl_modint_ops!(Add,add,AddAssign,add_assign,+);
impl_modint_ops!(Sub,sub,SubAssign,sub_assign,-);
impl_modint_ops!(Mul,mul,MulAssign,mul_assign,*);
impl_modint_ops!(Div,div,DivAssign,div_assign,/);


/// Adds up an iterator of residues; the empty sum is 0.
impl<const MOD:u32> std::iter::Sum for ModInt<MOD>{
    fn sum<I:Iterator<Item=Self>>(iter:I)->Self{
        let mut total=Self(0);
        for term in iter{
            total=total+term;
        }
        total
    }
}
/// Multiplies an iterator of residues; the empty product is `Self(1)`.
impl<const MOD:u32> std::iter::Product for ModInt<MOD>{
    fn product<I:Iterator<Item=Self>>(iter:I)->Self{
        let mut acc=Self(1);
        for factor in iter{
            acc=acc*factor;
        }
        acc
    }
}


/// Reduction of a primitive integer modulo a `u32`, used by `ModInt::new`.
trait RemU32{
    /// Returns `self` reduced modulo `m`, as a value in `0..m`.
    /// Negative signed inputs map to their non-negative residue.
    fn rem_u32(self,m:u32)->u32;
}
macro_rules! impl_rem_u32{
    (signed: $($ty:ty),*)=>{$(
        impl RemU32 for $ty{
            fn rem_u32(self,m:u32)->u32{
                // Widen to i64 so that rem_euclid yields the non-negative residue.
                let wide=self as i64;
                wide.rem_euclid(i64::from(m)) as u32
            }
        }
    )*};
    (unsigned: $($ty:ty),*)=>{$(
        impl RemU32 for $ty{
            fn rem_u32(self,m:u32)->u32{
                let modulus=m as $ty;
                (self%modulus) as u32
            }
        }
    )*};
}
impl_rem_u32!(signed: i32, i64);
impl_rem_u32!(unsigned: u32, u64, usize);


/// Common interface of the modular-integer types in this file (implemented by `ModInt<MOD>`).
///
/// The supertraits let generic code use the arithmetic operators, parsing,
/// hashing and printing on any implementor.
trait ModIntBase:Default+std::str::FromStr+Copy+Eq+std::hash::Hash+std::fmt::Display+std::fmt::Debug
    +std::ops::Neg<Output=Self>+std::ops::Add<Output=Self>+std::ops::Sub<Output=Self>
    +std::ops::Mul<Output=Self>+std::ops::Div<Output=Self>
    +std::ops::AddAssign+std::ops::SubAssign+std::ops::MulAssign+std::ops::DivAssign
{
    /// The modulus.
    fn modulus()->u32;
    /// The stored residue.
    fn val(self)->u32;
    /// Builds a value by reducing the primitive integer `v` modulo the modulus.
    fn new(v:impl RemU32)->Self;
    /// Multiplicative inverse.
    fn inv(self)->Self;
    /// `self` raised to the `k`-th power.
    fn pow(self,k:u64)->Self;
}



/// Lowest-common-ancestor oracle for a rooted tree. It is built from an Euler tour
/// plus a sparse table of range minima by depth.
///
/// Construction takes O(T log T) time and memory for a tour of length T, which is
/// O(n) for a tree. `get_lca` and `get_dist` are O(1).
#[derive(Clone)]
struct LCA{
    /// `range[v] = (first, last)`: indices of the first and last occurrence of `v`
    /// in the Euler tour. `last` stays `!0` for a root with no neighbours.
    range:Vec<(usize,usize)>,
    /// `sparse_table[k][i]` is the shallowest vertex in `tour[i..i + 2^k]`.
    /// Level 0 is the tour itself.
    sparse_table:Vec<Vec<usize>>,
    /// Depth of every vertex; the root has depth 0.
    rank:Vec<usize>,
}
impl LCA{
    /// Builds the oracle for `tree` (undirected adjacency lists) rooted at `root`.
    ///
    /// The Euler tour is generated iteratively, so deep trees such as long paths
    /// cannot overflow the call stack. An empty `tree` yields an empty oracle.
    ///
    /// # Panics
    /// If `tree` is non-empty and `root >= tree.len()`.
    fn new(tree:&[Vec<usize>],root:usize)->LCA{
        let n=tree.len();

        let mut tour=vec![];
        let mut range=vec![(!0,!0);n];
        let mut rank=vec![0;n];

        if n>0{
            // Explicit DFS stack of (vertex, parent, index of next neighbour to examine).
            // A vertex is appended when it is first entered, and again after each of its
            // neighbours is handled: after returning from a child, or immediately for
            // the parent entry.
            let mut stack=vec![(root,!0usize,0usize)];
            range[root].0=tour.len();
            tour.push(root);
            while let Some(frame)=stack.last_mut(){
                let (v,p,i)=*frame;
                if i==tree[v].len(){
                    stack.pop();
                    // Back in the parent's frame: the parent is visited again.
                    if let Some(&(parent,_,_))=stack.last(){
                        range[parent].1=tour.len();
                        tour.push(parent);
                    }
                    continue;
                }
                frame.2+=1;
                let nv=tree[v][i];
                if nv==p{
                    range[v].1=tour.len();
                    tour.push(v);
                } else{
                    rank[nv]=rank[v]+1;
                    range[nv].0=tour.len();
                    tour.push(nv);
                    stack.push((nv,v,0));
                }
            }
        }

        // Shallower of two vertices; ties pick the second argument.
        let min=|u:usize,v:usize|->usize{
            if rank[u]<rank[v]{
                u
            } else{
                v
            }
        };

        // Level k+1 combines two overlapping windows of level k.
        let mut sparse_table=vec![tour];
        for t in 0..{
            let step=1<<t;
            let last=sparse_table.last().unwrap();

            if last.len()<=step{
                break;
            }

            let new=(step..last.len()).map(|i|min(last[i-step],last[i])).collect();
            sparse_table.push(new);
        }

        LCA{range,sparse_table,rank}
    }

    /// Lowest common ancestor of `u` and `v`. Both must be reachable from the root.
    fn get_lca(&self,mut u:usize,mut v:usize)->usize{
        let rank=&self.rank;
        if rank[u]>rank[v]{
            (u,v)=(v,u);
        }

        let ru=self.range[u];
        let rv=self.range[v];

        // `u` is an ancestor of (or equal to) `v`: v's tour interval nests inside u's.
        if ru.0<=rv.0 && rv.1<=ru.1{
            return u;
        }

        let mut l=ru.0;
        let mut r=rv.0;

        if l>r{
            (l,r)=(r,l);
        }

        // The shallowest vertex in tour[l..r] is the LCA. Cover it with two
        // power-of-two windows.
        let step=(usize::BITS-1-(r-l).leading_zeros()) as usize;
        let a=self.sparse_table[step][l];
        let b=self.sparse_table[step][r-(1<<step)];

        if rank[a]<rank[b]{
            a
        } else{
            b
        }
    }

    /// Number of edges on the path between `u` and `v`.
    fn get_dist(&self,u:usize,v:usize)->usize{
        let lca=self.get_lca(u,v);
        self.rank[u]+self.rank[v]-self.rank[lca]*2
    }

    /// Edges of the virtual (auxiliary) tree of `vs`: the vertices of `vs` plus the
    /// LCAs needed to connect them.
    ///
    /// Every edge is returned as `(ancestor, descendant)`. The two endpoints are
    /// connected in the original tree by a path whose length is `get_dist` of the pair.
    /// An empty `vs` yields no edges.
    fn get_virtual_tree(&self,vs:&[usize])->Vec<(usize,usize)>{
        if vs.is_empty(){
            return vec![];
        }

        let rank=&self.rank;

        // Order by last Euler-tour occurrence: every subtree forms a contiguous run,
        // and descendants precede their ancestors.
        let mut vs=vs.to_vec();
        vs.sort_unstable_by_key(|&v|self.range[v].1);

        let mut es=vec![];
        // Invariant: `stack` is a chain in which each entry is an ancestor of the next.
        let mut stack=vec![vs[0]];

        for &v in &vs[1..]{
            let lca=self.get_lca(*stack.last().unwrap(),v);
            // Pop every chain vertex at or below `lca`, linking consecutive pops.
            let mut prev=!0;
            while let Some(&top)=stack.last(){
                if rank[lca]>rank[top]{
                    break;
                }

                stack.pop();
                if prev!=!0{
                    es.push((top,prev));
                }
                prev=top;
            }

            if prev!=!0 && lca!=prev{
                es.push((lca,prev));
            }
            stack.push(lca);
            if lca!=v{
                stack.push(v);
            }
        }

        // Remaining chain: each entry is the parent of the entry pushed after it.
        for w in stack.windows(2){
            es.push((w[0],w[1]));
        }

        es
    }
}
0