fn main() { let (a, e) = read(); let n = a.len(); let mut solver = RerootingDP::new( R, vec![0; n], e.into_iter().map(|e| (e.0 - 1, e.1 - 1, ())).collect(), ); let mut ord = (0..n).collect::>(); ord.sort_by_key(|v| a[*v]); let mut ans = M::zero(); for x in ord { solver.set_vertex(x, 1); let p = solver.find(x); ans += M::new(a[x]) * (p.con + p.dis).1; } println!("{}", ans); } type M = ModInt<998244353>; #[derive(Clone, Debug)] struct Data { con: Dual, dis: Dual, invalid: M, } struct R; impl TreeDP for R { type Vertex = u8; type Edge = (); type Path = Data; type Point = (Dual, M); fn vertex(&self, v: &Self::Vertex) -> Self::Path { if *v == 0 { Data { con: Dual::zero(), dis: Dual::zero(), invalid: M::one(), } } else { Data { con: Dual::new(M::one(), M::one()), dis: Dual::zero(), invalid: M::zero(), } } } fn add_edge(&self, p: &Self::Path, e: &Self::Edge) -> Self::Point { let invalid = p.invalid; let mut s = p.con + p.dis; s.0 = s.0 + s.0 + p.invalid; (s, invalid) } fn rake(&self, a: &Self::Point, b: &Self::Point) -> Self::Point { let (x, y) = *a; let (z, w) = *b; (x * z, (x.0 + y) * (z.0 + w) - x.0 * z.0) } fn add_vertex(&self, p: &Self::Point, v: &Self::Vertex) -> Self::Path { let (a, b) = *p; if *v == 0 { Data { con: Dual::zero(), dis: Dual::zero(), invalid: a.0 + b, } } else { Data { con: Dual::new(a.0, a.1 + a.0), dis: Dual::zero(), invalid: b, } } } fn compress(&self, p: &Self::Path, c: &Self::Path, e: &Self::Edge) -> Self::Path { let con = p.con * c.con; let mut dis = (p.con + p.dis) * Dual::new((c.con + c.dis).0 + c.invalid, M::zero()); dis = dis + p.con * c.dis; dis = dis + p.dis * Dual::new((c.con + c.dis).0 + c.invalid, M::zero()); let mut invalid = p.invalid * ((c.con + c.dis).0 + c.invalid) * M::new(2); invalid += p.con.0 * c.invalid; Data { con, dis, invalid, } } } fn read() -> (Vec, Vec<(usize, usize)>) { let mut s = String::new(); use std::io::*; std::io::stdin().read_to_string(&mut s).unwrap(); let mut it = s.trim().split_whitespace().flat_map(|s| s.parse::()); let mut next = || it.next().unwrap(); let n = next(); let a = (0..n).map(|_| next() as u32).collect(); let e = (1..n) .map(|_| { let a = next(); let b = next(); (a, b) }) .collect(); (a, e) } pub trait TreeDP { type Vertex: Clone; type Edge: Clone; type Path: Clone; type Point: Clone; fn vertex(&self, v: &Self::Vertex) -> Self::Path; fn add_edge(&self, p: &Self::Path, e: &Self::Edge) -> Self::Point; fn rake(&self, a: &Self::Point, b: &Self::Point) -> Self::Point; fn add_vertex(&self, p: &Self::Point, v: &Self::Vertex) -> Self::Path; fn compress(&self, p: &Self::Path, c: &Self::Path, e: &Self::Edge) -> Self::Path; } pub struct RerootingDP { op: R, v: Vec, e: Vec, sum: Vec>, stt: StaticTopTree, } impl RerootingDP where R: TreeDP, { const ROOT: usize = 0; pub fn new(op: R, v: Vec, edge: Vec<(usize, usize, R::Edge)>) -> Self { assert!(v.len() == edge.len() + 1); let mut e = vec![]; let mut memo = vec![]; for (a, b, w) in edge { e.push(w); memo.push((a, b)); } let stt = StaticTopTree::new(memo, Self::ROOT); let sum = vec![Union::V((op.vertex(&v[0]), op.vertex(&v[0]))); stt.label.len()]; let mut res = Self { op, v, e, sum, stt }; for i in 0..res.stt.label.len() { res.pull(i); } res } pub fn set_vertex(&mut self, v: usize, w: R::Vertex) { self.v[v] = w; self.update(self.stt.vertex[v]); } pub fn set_edge(&mut self, e: usize, w: R::Edge) { self.e[e] = w; self.update(self.stt.edge[e]); } pub fn find(&self, root: usize) -> R::Path { if root == Self::ROOT { return self.sum.last().unwrap().get_v().0.clone(); } // なんか非常に汚い、もっと綺麗に書けないか let mut pos = self.stt.vertex[root]; let mut memo = vec![]; while let Some(p) = self.stt.node[pos].p.get() { let l = self.stt.node[p].l.get().unwrap() == pos; pos = p; memo.push((p, l)); } let mut up: Option<(R::Path, R::Edge)> = None; let mut down: Option<(R::Path, R::Edge)> = None; let mut point: Option = None; let mut vertex: Option = None; for &(pos, left) in memo.iter().rev() { if self.stt.label[pos] == STTLabel::Compress { let e = &self.e[self.stt.node[pos].e.get().unwrap()]; if left { let r = self.stt.node[pos].r.get().unwrap(); let r = &self.sum[r].get_v().0; down = Some(down.map_or((r.clone(), e.clone()), |(a, b)| { (self.op.compress(r, &a, &b), e.clone()) })); } else { let l = self.stt.node[pos].l.get().unwrap(); let l = &self.sum[l].get_v().1; up = Some(up.map_or((l.clone(), e.clone()), |(a, b)| { (self.op.compress(l, &a, &b), e.clone()) })); } } else if self.stt.label[pos] == STTLabel::AddVertex { vertex = Some(self.v[self.stt.node[pos].e.get().unwrap()].clone()); let u = up.take().map(|p| self.op.add_edge(&p.0, &p.1)); let d = down.take().map(|p| self.op.add_edge(&p.0, &p.1)); point = match (u, d) { (Some(a), Some(b)) => Some(self.op.rake(&a, &b)), (a, b) => a.or(b), }; } else if self.stt.label[pos] == STTLabel::Rake { let other = if left { self.stt.node[pos].r } else { self.stt.node[pos].l } .get() .unwrap(); let p = self.sum[other].get_e(); point = Some(point.map_or(p.clone(), |q| self.op.rake(p, &q))); } else if self.stt.label[pos] == STTLabel::AddEdge { let e = &self.e[self.stt.node[pos].e.get().unwrap()]; if point.is_some() { let p = point.take().unwrap(); let v = vertex.take().unwrap(); up = Some((self.op.add_vertex(&p, &v), e.clone())); } else { unreachable!() } } else { unreachable!() } } let pos = self.stt.vertex[root]; if self.stt.label[pos] == STTLabel::AddVertex { let u = up.map(|p| self.op.add_edge(&p.0, &p.1)); let d = down.map(|p| self.op.add_edge(&p.0, &p.1)); let p = match (u, d) { (Some(a), Some(b)) => Some(self.op.rake(&a, &b)), (a, b) => a.or(b), } .unwrap(); let c = self.sum[self.stt.node[pos].l.get().unwrap()].get_e(); let q = self.op.rake(&p, c); self.op.add_vertex(&q, &self.v[root]) } else { let u = up.map(|p| self.op.add_edge(&p.0, &p.1)); let d = down.map(|p| self.op.add_edge(&p.0, &p.1)); let p = match (u, d) { (Some(a), Some(b)) => Some(self.op.rake(&a, &b)), (a, b) => a.or(b), } .unwrap(); self.op.add_vertex(&p, &self.v[root]) } } fn update(&mut self, mut v: usize) { self.pull(v); while let Some(p) = self.stt.node[v].p.get() { v = p; self.pull(p); } } fn pull(&mut self, v: usize) { match self.stt.label[v] { STTLabel::Vertex => { let u = self.stt.node[v].e.get().unwrap(); let p = self.op.vertex(&self.v[u]); self.sum[v].set_v((p.clone(), p)); } STTLabel::AddEdge => { let l = self.stt.node[v].l.get().unwrap(); let e = self.stt.node[v].e.get().unwrap(); let path = &self.sum[l].get_v().0; let point = self.op.add_edge(path, &self.e[e]); self.sum[v].set_e(point); } STTLabel::Rake => { let l = self.stt.node[v].l.get().unwrap(); let r = self.stt.node[v].r.get().unwrap(); let point = self.op.rake(self.sum[l].get_e(), self.sum[r].get_e()); self.sum[v].set_e(point); } STTLabel::AddVertex => { let l = self.stt.node[v].l.get().unwrap(); let u = self.stt.node[v].e.get().unwrap(); let path = self.op.add_vertex(self.sum[l].get_e(), &self.v[u]); self.sum[v].set_v((path.clone(), path)); } STTLabel::Compress => { let l = self.sum[self.stt.node[v].l.get().unwrap()].get_v(); let r = self.sum[self.stt.node[v].r.get().unwrap()].get_v(); let e = self.stt.node[v].e.get().unwrap(); let lr = self.op.compress(&l.0, &r.0, &self.e[e]); let rl = self.op.compress(&r.1, &l.1, &self.e[e]); self.sum[v].set_v((lr, rl)); } } } } pub struct FixRootTreeDP { op: R, v: Vec, e: Vec, sum: Vec>, stt: StaticTopTree, } impl FixRootTreeDP where R: TreeDP, { pub fn new(op: R, v: Vec, edge: Vec<(usize, usize, R::Edge)>) -> Self { assert!(v.len() == edge.len() + 1); let mut e = vec![]; let mut memo = vec![]; for (a, b, w) in edge { e.push(w); memo.push((a, b)); } let stt = StaticTopTree::new(memo, 0); let sum = vec![Union::V(op.vertex(&v[0])); stt.label.len()]; let mut res = Self { op, v, e, sum, stt }; for i in 0..res.stt.label.len() { res.pull(i); } res } pub fn set_vertex(&mut self, v: usize, w: R::Vertex) { self.v[v] = w; self.update(self.stt.vertex[v]); } pub fn set_edge(&mut self, e: usize, w: R::Edge) { self.e[e] = w; self.update(self.stt.edge[e]); } pub fn find(&self) -> R::Path { self.sum.last().unwrap().get_v().clone() } fn update(&mut self, mut v: usize) { self.pull(v); while let Some(p) = self.stt.node[v].p.get() { v = p; self.pull(p); } } fn pull(&mut self, v: usize) { match self.stt.label[v] { STTLabel::Vertex => { let u = self.stt.node[v].e.get().unwrap(); self.sum[v].set_v(self.op.vertex(&self.v[u])); } STTLabel::AddEdge => { let l = self.stt.node[v].l.get().unwrap(); let e = self.stt.node[v].e.get().unwrap(); let path = self.sum[l].get_v(); let point = self.op.add_edge(path, &self.e[e]); self.sum[v].set_e(point); } STTLabel::Rake => { let l = self.stt.node[v].l.get().unwrap(); let r = self.stt.node[v].r.get().unwrap(); let point = self.op.rake(self.sum[l].get_e(), self.sum[r].get_e()); self.sum[v].set_e(point); } STTLabel::AddVertex => { let l = self.stt.node[v].l.get().unwrap(); let u = self.stt.node[v].e.get().unwrap(); let path = self.op.add_vertex(self.sum[l].get_e(), &self.v[u]); self.sum[v].set_v(path); } STTLabel::Compress => { let l = self.sum[self.stt.node[v].l.get().unwrap()].get_v(); let r = self.sum[self.stt.node[v].r.get().unwrap()].get_v(); let e = self.stt.node[v].e.get().unwrap(); let path = self.op.compress(l, r, &self.e[e]); self.sum[v].set_v(path); } } } } #[derive(Clone, Debug)] enum Union { V(V), E(E), } impl Union { fn set_v(&mut self, v: V) { *self = Self::V(v); } fn set_e(&mut self, e: E) { *self = Self::E(e); } fn get_v(&self) -> &V { let Union::V(ref v) = self else { unreachable!() }; v } fn get_e(&self) -> &E { let Union::E(ref v) = self else { unreachable!() }; v } } #[derive(Debug, Clone, Copy, Eq, PartialEq)] enum STTLabel { Vertex, AddEdge, Rake, AddVertex, Compress, } #[derive(Clone, Debug)] struct STTNode { p: Pointer, l: Pointer, r: Pointer, e: Pointer, } impl STTNode { fn new(l: Pointer, r: Pointer, e: Pointer) -> Self { Self { p: Pointer::null(), l, r, e, } } } pub struct StaticTopTree { label: Vec, node: Vec, height: Vec, vertex: Vec, edge: Vec, size: usize, } impl StaticTopTree { pub fn new(edge: Vec<(usize, usize)>, root: usize) -> Self { let size = edge.len() + 1; let mut graph = vec![vec![]; size]; for (i, &(a, b)) in edge.iter().enumerate() { graph[a].push((b, i)); graph[b].push((a, i)); } let mut topo = vec![root]; let mut parent = vec![(size, size); size]; let mut inv_edge = vec![size; size]; for i in 0..size { let v = topo[i]; for (u, k) in graph[v].clone() { graph[u].retain(|p| p.0 != v); parent[u] = (v, k); inv_edge[k] = u; topo.push(u); } } let mut s = vec![1i32; size]; for &v in topo.iter().rev() { let c = &mut graph[v]; for i in 1..c.len() { if s[c[i].0] > s[c[0].0] { c.swap(0, i); } } s[v] += c.iter().map(|e| s[e.0]).sum::(); } let mut stt = Self { label: vec![], node: vec![], height: vec![], vertex: vec![!0; size], edge: vec![!0; size - 1], size, }; let mut id = vec![!0; size]; for &v in topo.iter().rev() { if graph[v].len() <= 1 { id[v] = stt.append_inner(!0, !0, v, STTLabel::Vertex); } else { let mut array = vec![None; 64]; let mut bit = 0usize; for &(u, e) in graph[v][1..].iter() { let mut k = stt.append_inner(id[u], !0, e, STTLabel::AddEdge); let mut h = stt.height[k]; while let Some(x) = array[h].take() { bit ^= 1 << h; k = stt.append_inner(k, x, !0, STTLabel::Rake); h = stt.height[k]; } array[h] = Some(k); bit |= 1 << h; } let x = bit.trailing_zeros() as usize; let mut k = array[x].take().unwrap(); bit ^= 1 << x; while bit > 0 { let x = bit.trailing_zeros() as usize; let u = array[x].take().unwrap(); k = stt.append_inner(k, u, !0, STTLabel::Rake); bit ^= 1 << x; } id[v] = stt.append_inner(k, !0, v, STTLabel::AddVertex); } if v == root || graph[parent[v].0][0].0 != v { let mut stack = vec![(id[v], size)]; let mut pos = v; while let Some(&(u, k)) = graph[pos].get(0) { stack.push((id[u], k)); while stack.len() > 1 { let len = stack.len(); let (b, a) = (stack[len - 2], stack[len - 1]); if len >= 3 && stt.height[stack[len - 3].0] <= stt.height[a.0] { let c = stack[len - 3]; stack.truncate(len - 3); let v = stt.append_inner(c.0, b.0, b.1, STTLabel::Compress); stack.extend([(v, c.1), a].iter().cloned()); } else if stt.height[b.0] <= stt.height[a.0] { stack.truncate(len - 2); let v = stt.append_inner(b.0, a.0, a.1, STTLabel::Compress); stack.push((v, b.1)); } else { break; } } pos = u; } while stack.len() >= 2 { let a = stack.pop().unwrap(); let b = stack.pop().unwrap(); let v = stt.append_inner(b.0, a.0, a.1, STTLabel::Compress); stack.push((v, b.1)); } id[v] = stack.pop().unwrap().0; } } stt } fn append_inner(&mut self, l: usize, r: usize, e: usize, label: STTLabel) -> usize { let v = self.node.len(); let mut h = 0; let lp = if let Some(n) = self.node.get_mut(l) { n.p.set(v); h = std::cmp::max(h, self.height[l]); Pointer::new(l) } else { Pointer::null() }; let rp = if let Some(n) = self.node.get_mut(r) { n.p.set(v); h = std::cmp::max(h, self.height[r]); Pointer::new(r) } else { Pointer::null() }; if self.node.get(r).is_some() { h += 1; } let ep = if e < self.size { if label == STTLabel::Vertex || label == STTLabel::AddVertex { self.vertex[e] = v; } else if label == STTLabel::AddEdge || label == STTLabel::Compress { self.edge[e] = v; } else { unreachable!(); } Pointer::new(e) } else { Pointer::null() }; self.label.push(label); self.node.push(STTNode::new(lp, rp, ep)); self.height.push(h); v } } // ---------- begin pointer ---------- #[derive(Clone, Copy)] pub struct Pointer(u32); impl Pointer { pub fn new(v: usize) -> Self { Self(v as u32) } pub fn null() -> Self { Self(!0) } pub fn get(&self) -> Option { if self.0 == !0 { None } else { Some(self.0 as usize) } } pub fn is_null(&self) -> bool { self.get().is_none() } pub fn set(&mut self, v: usize) { self.0 = v as u32 } } impl From for Pointer { fn from(x: usize) -> Self { Self::new(x) } } impl Default for Pointer { fn default() -> Self { Self::null() } } impl std::fmt::Debug for Pointer { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { if let Some(x) = self.get() { write!(f, "{}", x) } else { write!(f, "null") } } } // ---------- end pointer ---------- mod util { pub trait Join { fn join(self, sep: &str) -> String; } impl Join for I where I: Iterator, T: std::fmt::Display, { fn join(self, sep: &str) -> String { let mut s = String::new(); use std::fmt::*; for (i, v) in self.enumerate() { if i > 0 { write!(&mut s, "{}", sep).ok(); } write!(&mut s, "{}", v).ok(); } s } } } // ---------- begin modint ---------- pub const fn pow_mod(mut r: u32, mut n: u32, m: u32) -> u32 { let mut t = 1; while n > 0 { if n & 1 == 1 { t = (t as u64 * r as u64 % m as u64) as u32; } r = (r as u64 * r as u64 % m as u64) as u32; n >>= 1; } t } pub const fn primitive_root(p: u32) -> u32 { let mut m = p - 1; let mut f = [1; 30]; let mut k = 0; let mut d = 2; while d * d <= m { if m % d == 0 { f[k] = d; k += 1; } while m % d == 0 { m /= d; } d += 1; } if m > 1 { f[k] = m; k += 1; } let mut g = 1; while g < p { let mut ok = true; let mut i = 0; while i < k { ok &= pow_mod(g, (p - 1) / f[i], p) > 1; i += 1; } if ok { break; } g += 1; } g } pub const fn is_prime(n: u32) -> bool { if n <= 1 { return false; } let mut d = 2; while d * d <= n { if n % d == 0 { return false; } d += 1; } true } #[derive(Clone, Copy, PartialEq, Eq)] pub struct ModInt(u32); impl ModInt<{ M }> { const REM: u32 = { let mut t = 1u32; let mut s = !M + 1; let mut n = !0u32 >> 2; while n > 0 { if n & 1 == 1 { t = t.wrapping_mul(s); } s = s.wrapping_mul(s); n >>= 1; } t }; const INI: u64 = ((1u128 << 64) % M as u128) as u64; const IS_PRIME: () = assert!(is_prime(M)); const PRIMITIVE_ROOT: u32 = primitive_root(M); const ORDER: usize = 1 << (M - 1).trailing_zeros(); const fn reduce(x: u64) -> u32 { let _ = Self::IS_PRIME; let b = (x as u32 * Self::REM) as u64; let t = x + b * M as u64; let mut c = (t >> 32) as u32; if c >= M { c -= M; } c as u32 } const fn multiply(a: u32, b: u32) -> u32 { Self::reduce(a as u64 * b as u64) } pub const fn new(v: u32) -> Self { assert!(v < M); Self(Self::reduce(v as u64 * Self::INI)) } pub const fn const_mul(&self, rhs: Self) -> Self { Self(Self::multiply(self.0, rhs.0)) } pub const fn pow(&self, mut n: u64) -> Self { let mut t = Self::new(1); let mut r = *self; while n > 0 { if n & 1 == 1 { t = t.const_mul(r); } r = r.const_mul(r); n >>= 1; } t } pub const fn inv(&self) -> Self { assert!(self.0 != 0); self.pow(M as u64 - 2) } pub const fn get(&self) -> u32 { Self::reduce(self.0 as u64) } pub const fn zero() -> Self { Self::new(0) } pub const fn one() -> Self { Self::new(1) } } impl Add for ModInt<{ M }> { type Output = Self; fn add(self, rhs: Self) -> Self::Output { let mut v = self.0 + rhs.0; if v >= M { v -= M; } Self(v) } } impl Sub for ModInt<{ M }> { type Output = Self; fn sub(self, rhs: Self) -> Self::Output { let mut v = self.0 - rhs.0; if self.0 < rhs.0 { v += M; } Self(v) } } impl Mul for ModInt<{ M }> { type Output = Self; fn mul(self, rhs: Self) -> Self::Output { self.const_mul(rhs) } } impl Div for ModInt<{ M }> { type Output = Self; fn div(self, rhs: Self) -> Self::Output { self * rhs.inv() } } impl AddAssign for ModInt<{ M }> { fn add_assign(&mut self, rhs: Self) { *self = *self + rhs; } } impl SubAssign for ModInt<{ M }> { fn sub_assign(&mut self, rhs: Self) { *self = *self - rhs; } } impl MulAssign for ModInt<{ M }> { fn mul_assign(&mut self, rhs: Self) { *self = *self * rhs; } } impl DivAssign for ModInt<{ M }> { fn div_assign(&mut self, rhs: Self) { *self = *self / rhs; } } impl Neg for ModInt<{ M }> { type Output = Self; fn neg(self) -> Self::Output { if self.0 == 0 { self } else { Self(M - self.0) } } } impl std::fmt::Display for ModInt<{ M }> { fn fmt<'a>(&self, f: &mut std::fmt::Formatter<'a>) -> std::fmt::Result { write!(f, "{}", self.get()) } } impl std::fmt::Debug for ModInt<{ M }> { fn fmt<'a>(&self, f: &mut std::fmt::Formatter<'a>) -> std::fmt::Result { write!(f, "{}", self.get()) } } impl std::str::FromStr for ModInt<{ M }> { type Err = std::num::ParseIntError; fn from_str(s: &str) -> Result { let val = s.parse::()?; Ok(ModInt::new(val)) } } impl From for ModInt<{ M }> { fn from(val: usize) -> ModInt<{ M }> { ModInt::new((val % M as usize) as u32) } } // ---------- end modint ---------- // ---------- begin precalc ---------- pub struct Precalc { fact: Vec>, ifact: Vec>, inv: Vec>, } impl Precalc { pub fn new(size: usize) -> Self { let mut fact = vec![ModInt::one(); size + 1]; let mut ifact = vec![ModInt::one(); size + 1]; let mut inv = vec![ModInt::one(); size + 1]; for i in 2..=size { fact[i] = fact[i - 1] * ModInt::from(i); } ifact[size] = fact[size].inv(); for i in (2..=size).rev() { inv[i] = ifact[i] * fact[i - 1]; ifact[i - 1] = ifact[i] * ModInt::from(i); } Self { fact, ifact, inv } } pub fn fact(&self, n: usize) -> ModInt { self.fact[n] } pub fn ifact(&self, n: usize) -> ModInt { self.ifact[n] } pub fn inv(&self, n: usize) -> ModInt { assert!(0 < n); self.inv[n] } pub fn perm(&self, n: usize, k: usize) -> ModInt { if k > n { return ModInt::zero(); } self.fact[n] * self.ifact[n - k] } pub fn binom(&self, n: usize, k: usize) -> ModInt { if n < k { return ModInt::zero(); } self.fact[n] * self.ifact[k] * self.ifact[n - k] } } // ---------- end precalc ---------- impl Zero for ModInt<{ M }> { fn zero() -> Self { Self::zero() } fn is_zero(&self) -> bool { self.0 == 0 } } impl One for ModInt<{ M }> { fn one() -> Self { Self::one() } fn is_one(&self) -> bool { self.get() == 1 } } // ---------- begin array op ---------- struct NTTPrecalc { sum_e: [ModInt<{ M }>; 30], sum_ie: [ModInt<{ M }>; 30], } impl NTTPrecalc<{ M }> { const fn new() -> Self { let cnt2 = (M - 1).trailing_zeros() as usize; let root = ModInt::new(ModInt::<{ M }>::PRIMITIVE_ROOT); let zeta = root.pow((M - 1) as u64 >> cnt2); let mut es = [ModInt::zero(); 30]; let mut ies = [ModInt::zero(); 30]; let mut sum_e = [ModInt::zero(); 30]; let mut sum_ie = [ModInt::zero(); 30]; let mut e = zeta; let mut ie = e.inv(); let mut i = cnt2; while i >= 2 { es[i - 2] = e; ies[i - 2] = ie; e = e.const_mul(e); ie = ie.const_mul(ie); i -= 1; } let mut now = ModInt::one(); let mut inow = ModInt::one(); let mut i = 0; while i < cnt2 - 1 { sum_e[i] = es[i].const_mul(now); sum_ie[i] = ies[i].const_mul(inow); now = ies[i].const_mul(now); inow = es[i].const_mul(inow); i += 1; } Self { sum_e, sum_ie } } } struct NTTPrecalcHelper; impl NTTPrecalcHelper { const A: NTTPrecalc = NTTPrecalc::new(); } pub trait ArrayAdd { type Item; fn add(&self, rhs: &[Self::Item]) -> Vec; } impl ArrayAdd for [T] where T: Zero + Copy, { type Item = T; fn add(&self, rhs: &[Self::Item]) -> Vec { let mut c = vec![T::zero(); self.len().max(rhs.len())]; c[..self.len()].copy_from_slice(self); c.add_assign(rhs); c } } pub trait ArrayAddAssign { type Item; fn add_assign(&mut self, rhs: &[Self::Item]); } impl ArrayAddAssign for [T] where T: Add + Copy, { type Item = T; fn add_assign(&mut self, rhs: &[Self::Item]) { assert!(self.len() >= rhs.len()); self.iter_mut().zip(rhs).for_each(|(x, a)| *x = *x + *a); } } impl ArrayAddAssign for Vec where T: Zero + Add + Copy, { type Item = T; fn add_assign(&mut self, rhs: &[Self::Item]) { if self.len() < rhs.len() { self.resize(rhs.len(), T::zero()); } self.as_mut_slice().add_assign(rhs); } } pub trait ArraySub { type Item; fn sub(&self, rhs: &[Self::Item]) -> Vec; } impl ArraySub for [T] where T: Zero + Sub + Copy, { type Item = T; fn sub(&self, rhs: &[Self::Item]) -> Vec { let mut c = vec![T::zero(); self.len().max(rhs.len())]; c[..self.len()].copy_from_slice(self); c.sub_assign(rhs); c } } pub trait ArraySubAssign { type Item; fn sub_assign(&mut self, rhs: &[Self::Item]); } impl ArraySubAssign for [T] where T: Sub + Copy, { type Item = T; fn sub_assign(&mut self, rhs: &[Self::Item]) { assert!(self.len() >= rhs.len()); self.iter_mut().zip(rhs).for_each(|(x, a)| *x = *x - *a); } } impl ArraySubAssign for Vec where T: Zero + Sub + Copy, { type Item = T; fn sub_assign(&mut self, rhs: &[Self::Item]) { if self.len() < rhs.len() { self.resize(rhs.len(), T::zero()); } self.as_mut_slice().sub_assign(rhs); } } pub trait ArrayDot { type Item; fn dot(&self, rhs: &[Self::Item]) -> Vec; } impl ArrayDot for [T] where T: Mul + Copy, { type Item = T; fn dot(&self, rhs: &[Self::Item]) -> Vec { assert!(self.len() == rhs.len()); self.iter().zip(rhs).map(|p| *p.0 * *p.1).collect() } } pub trait ArrayDotAssign { type Item; fn dot_assign(&mut self, rhs: &[Self::Item]); } impl ArrayDotAssign for [T] where T: MulAssign + Copy, { type Item = T; fn dot_assign(&mut self, rhs: &[Self::Item]) { assert!(self.len() == rhs.len()); self.iter_mut().zip(rhs).for_each(|(x, a)| *x *= *a); } } pub trait ArrayMul { type Item; fn mul(&self, rhs: &[Self::Item]) -> Vec; } impl ArrayMul for [T] where T: Zero + One + Copy, { type Item = T; fn mul(&self, rhs: &[Self::Item]) -> Vec { if self.is_empty() || rhs.is_empty() { return vec![]; } let mut res = vec![T::zero(); self.len() + rhs.len() - 1]; for (i, a) in self.iter().enumerate() { for (res, b) in res[i..].iter_mut().zip(rhs.iter()) { *res = *res + *a * *b; } } res } } // transform でlen=1を指定すればNTTになる pub trait ArrayConvolution { type Item; fn transform(&mut self, len: usize); fn inverse_transform(&mut self, len: usize); fn convolution(&self, rhs: &[Self::Item]) -> Vec; } impl ArrayConvolution for [ModInt<{ M }>] { type Item = ModInt<{ M }>; fn transform(&mut self, len: usize) { let f = self; let n = f.len(); let k = (n / len).trailing_zeros() as usize; assert!(len << k == n); assert!(k <= ModInt::<{ M }>::ORDER); let pre = &NTTPrecalcHelper::<{ M }>::A; for ph in 1..=k { let p = len << (k - ph); let mut now = ModInt::one(); for (i, f) in f.chunks_exact_mut(2 * p).enumerate() { let (x, y) = f.split_at_mut(p); for (x, y) in x.iter_mut().zip(y.iter_mut()) { let l = *x; let r = *y * now; *x = l + r; *y = l - r; } now *= pre.sum_e[(!i).trailing_zeros() as usize]; } } } fn inverse_transform(&mut self, len: usize) { let f = self; let n = f.len(); let k = (n / len).trailing_zeros() as usize; assert!(len << k == n); assert!(k <= ModInt::<{ M }>::ORDER); let pre = &NTTPrecalcHelper::<{ M }>::A; for ph in (1..=k).rev() { let p = len << (k - ph); let mut inow = ModInt::one(); for (i, f) in f.chunks_exact_mut(2 * p).enumerate() { let (x, y) = f.split_at_mut(p); for (x, y) in x.iter_mut().zip(y.iter_mut()) { let l = *x; let r = *y; *x = l + r; *y = (l - r) * inow; } inow *= pre.sum_ie[(!i).trailing_zeros() as usize]; } } let ik = ModInt::new(2).inv().pow(k as u64); for f in f.iter_mut() { *f *= ik; } } fn convolution(&self, rhs: &[Self::Item]) -> Vec { if self.len().min(rhs.len()) <= 32 { return self.mul(rhs); } const PARAM: usize = 10; let size = self.len() + rhs.len() - 1; let mut k = 0; while (size + (1 << k) - 1) >> k > PARAM { k += 1; } let len = (size + (1 << k) - 1) >> k; let mut f = vec![ModInt::zero(); len << k]; let mut g = vec![ModInt::zero(); len << k]; f[..self.len()].copy_from_slice(self); g[..rhs.len()].copy_from_slice(rhs); f.transform(len); g.transform(len); let mut buf = [ModInt::zero(); 2 * PARAM - 1]; let buf = &mut buf[..(2 * len - 1)]; let pre = &NTTPrecalcHelper::<{ M }>::A; let mut now = ModInt::one(); for (i, (f, g)) in f .chunks_exact_mut(2 * len) .zip(g.chunks_exact(2 * len)) .enumerate() { let mut r = now; for (f, g) in f.chunks_exact_mut(len).zip(g.chunks_exact(len)) { buf.fill(ModInt::zero()); for (i, f) in f.iter().enumerate() { for (buf, g) in buf[i..].iter_mut().zip(g.iter()) { *buf = *buf + *f * *g; } } f.copy_from_slice(&buf[..len]); for (f, buf) in f.iter_mut().zip(buf[len..].iter()) { *f = *f + r * *buf; } r = -r; } now *= pre.sum_e[(!i).trailing_zeros() as usize]; } f.inverse_transform(len); f.truncate(self.len() + rhs.len() - 1); f } } // ---------- end array op ---------- // ---------- begin trait ---------- use std::ops::*; pub trait Zero: Sized + Add { fn zero() -> Self; fn is_zero(&self) -> bool; } pub trait One: Sized + Mul { fn one() -> Self; fn is_one(&self) -> bool; } pub trait Group: Zero + Sub + Neg {} pub trait SemiRing: Zero + One {} pub trait Ring: SemiRing + Group {} pub trait Field: Ring + Div {} impl Group for T where T: Zero + Sub + Neg {} impl SemiRing for T where T: Zero + One {} impl Ring for T where T: SemiRing + Group {} impl Field for T where T: Ring + Div {} pub fn zero() -> T { T::zero() } pub fn one() -> T { T::one() } pub fn pow(mut r: T, mut n: usize) -> T { let mut t = one(); while n > 0 { if n & 1 == 1 { t = t * r.clone(); } r = r.clone() * r; n >>= 1; } t } pub fn pow_sum(mut r: T, mut n: usize) -> T { let mut ans = T::zero(); let mut sum = T::one(); while n > 0 { if n & 1 == 1 { ans = ans * r.clone() + sum.clone(); } sum = sum * (T::one() + r.clone()); r = r.clone() * r; n >>= 1; } ans } // ---------- end trait ---------- #[derive(Clone, Copy, Default, Debug)] pub struct Dual(T, T); impl Dual { pub fn new(a: T, b: T) -> Self { Self(a, b) } } impl Zero for Dual where T: Zero, { fn zero() -> Self { Self::new(T::zero(), T::zero()) } fn is_zero(&self) -> bool { self.0.is_zero() && self.1.is_zero() } } impl One for Dual where T: One + Zero + Clone, { fn one() -> Self { Self::new(T::one(), T::zero()) } fn is_one(&self) -> bool { self.0.is_one() && self.1.is_zero() } } impl Add for Dual where T: Add, { type Output = Self; fn add(self, rhs: Self) -> Self { Self::new(self.0 + rhs.0, self.1 + rhs.1) } } impl AddAssign for Dual where T: Add + Clone, { fn add_assign(&mut self, rhs: Self) { *self = self.clone() + rhs; } } impl Sub for Dual where T: Sub, { type Output = Self; fn sub(self, rhs: Self) -> Self { Self::new(self.0 - rhs.0, self.1 - rhs.1) } } impl SubAssign for Dual where T: Sub + Clone, { fn sub_assign(&mut self, rhs: Self) { *self = self.clone() - rhs; } } impl Mul for Dual where T: Clone + Add + Mul, { type Output = Self; fn mul(self, rhs: Self) -> Self { Self::new( self.0.clone() * rhs.0.clone(), self.0 * rhs.1 + self.1 * rhs.0, ) } } impl MulAssign for Dual where T: Clone + Add + Mul, { fn mul_assign(&mut self, rhs: Self) { *self = self.clone() * rhs; } }