結果
| 問題 | No.2116 Making Forest Hard |
| コンテスト | |
| ユーザー | akakimidori |
| 提出日時 | 2026-03-21 20:44:00 |
| 言語 | Rust (1.93.0 + proconio + num + itertools) |
| 結果 | AC |
| 実行時間 | 302 ms / 8,000 ms |
| コード長 | 39,317 bytes |
| 記録 | |
| コンパイル時間 | 7,065 ms |
| コンパイル使用メモリ | 231,000 KB |
| 実行使用メモリ | 47,516 KB |
| 最終ジャッジ日時 | 2026-03-21 20:44:26 |
| 合計ジャッジ時間 | 18,893 ms |
| ジャッジサーバーID (参考情報) | judge2_0 / judge1_1 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 2 |
| other | AC * 53 |
コンパイルメッセージ
warning: unused variable: `e`
--> src/main.rs:50:40
|
50 | fn add_edge(&self, p: &Self::Path, e: &Self::Edge) -> Self::Point {
| ^ help: if this is intentional, prefix it with an underscore: `_e`
|
= note: `#[warn(unused_variables)]` (part of `#[warn(unused)]`) on by default
warning: unused variable: `e`
--> src/main.rs:77:56
|
77 | fn compress(&self, p: &Self::Path, c: &Self::Path, e: &Self::Edge) -> Self::Path {
| ^ help: if this is intentional, prefix it with an underscore: `_e`
warning: trait `Join` is never used
--> src/main.rs:625:15
|
625 | pub trait Join {
| ^^^^
|
= note: `#[warn(dead_code)]` (part of `#[warn(unused)]`) on by default
ソースコード
/// Entry point: reads the weighted tree and prints the accumulated answer.
///
/// Vertices start inactive (vertex weight `0`) and are switched on one at a
/// time in increasing order of `a`. `sort_by_key` is stable, so ties are taken
/// in index order. Right after vertex `x` is switched on, the rerooting DP is
/// evaluated with `x` as root, and `a[x]` times the `.1` (dual) part of
/// `con + dis` is added to the answer.
/// NOTE(review): presumably that dual part counts the configurations whose
/// relevant maximum is `x`; confirm against the problem statement.
fn main() {
    let (a, e) = read();
    let n = a.len();
    // Input vertices are 1-based; the DP uses 0-based ids and unit edges.
    let edges: Vec<(usize, usize, ())> = e.into_iter().map(|(u, v)| (u - 1, v - 1, ())).collect();
    let mut solver = RerootingDP::new(R, vec![0; n], edges);
    let mut ord: Vec<usize> = (0..n).collect();
    ord.sort_by_key(|&v| a[v]);
    let mut ans = M::zero();
    for x in ord {
        solver.set_vertex(x, 1);
        let p = solver.find(x);
        // `M::new` asserts its argument is already < 998244353.
        // `M::from` reduces first, so larger weights cannot trigger a panic.
        ans += M::from(a[x] as usize) * (p.con + p.dis).1;
    }
    println!("{}", ans);
}
/// Modular integer used for every count in this program.
type M = ModInt<998244353>;
/// Aggregate stored for a path cluster of the top tree.
///
/// `Dual` values are first-order dual numbers `(value, coefficient of ε)`.
/// NOTE(review): the field meanings below are inferred from how `R` combines
/// them and should be confirmed against the problem statement:
/// - `con`: weight of configurations where the cluster's top end is in a
///   component made only of active vertices (presumably);
/// - `dis`: the analogous weight once that component has been cut off;
/// - `invalid`: weight of configurations touching an inactive vertex.
#[derive(Clone, Debug)]
struct Data {
    con: Dual<M>,
    dis: Dual<M>,
    invalid: M,
}
/// Operator set for this problem's rerooting DP.
///
/// A vertex weight of `0` means "inactive"; any other value means "active".
struct R;
impl TreeDP for R {
    type Vertex = u8;
    type Edge = ();
    type Path = Data;
    type Point = (Dual<M>, M);
    /// Single-vertex cluster.
    /// Inactive: only `invalid = 1`. Active: `con = 1 + 1·ε`.
    fn vertex(&self, v: &Self::Vertex) -> Self::Path {
        if *v == 0 {
            Data {
                con: Dual::zero(),
                dis: Dual::zero(),
                invalid: M::one(),
            }
        } else {
            Data {
                con: Dual::new(M::one(), M::one()),
                dis: Dual::zero(),
                invalid: M::zero(),
            }
        }
    }
    /// Hangs a child path below a vertex through an edge.
    ///
    /// Returns `(s, invalid)`, where `s = con + dis` has its value part replaced
    /// by `2·s.0 + invalid`. Presumably the factor 2 counts keeping vs. cutting
    /// the connecting edge (TODO confirm). Edges carry no data, hence `_e`.
    fn add_edge(&self, p: &Self::Path, _e: &Self::Edge) -> Self::Point {
        let invalid = p.invalid;
        let mut s = p.con + p.dis;
        s.0 = s.0 + s.0 + p.invalid;
        (s, invalid)
    }
    /// Merges two sibling points.
    ///
    /// Dual parts multiply. The second component is
    /// `(x.0 + y)(z.0 + w) - x.0·z.0`: every combination except
    /// "both siblings valid".
    fn rake(&self, a: &Self::Point, b: &Self::Point) -> Self::Point {
        let (x, y) = *a;
        let (z, w) = *b;
        (x * z, (x.0 + y) * (z.0 + w) - x.0 * z.0)
    }
    /// Attaches the merged children `p` to vertex `v`, producing a path cluster.
    ///
    /// An inactive vertex turns everything into `invalid`. An active vertex
    /// keeps the children's dual number and adds its value to the
    /// ε-coefficient.
    fn add_vertex(&self, p: &Self::Point, v: &Self::Vertex) -> Self::Path {
        let (a, b) = *p;
        if *v == 0 {
            Data {
                con: Dual::zero(),
                dis: Dual::zero(),
                invalid: a.0 + b,
            }
        } else {
            Data {
                con: Dual::new(a.0, a.1 + a.0),
                dis: Dual::zero(),
                invalid: b,
            }
        }
    }
    /// Joins the upper path `p` with the lower path `c`; edges carry no data.
    fn compress(&self, p: &Self::Path, c: &Self::Path, _e: &Self::Edge) -> Self::Path {
        // Total value weight of the child cluster; used three times below.
        let c_total = (c.con + c.dis).0 + c.invalid;
        let con = p.con * c.con;
        let mut dis = (p.con + p.dis) * Dual::new(c_total, M::zero());
        dis = dis + p.con * c.dis;
        dis = dis + p.dis * Dual::new(c_total, M::zero());
        let mut invalid = p.invalid * c_total * M::new(2);
        invalid += p.con.0 * c.invalid;
        Data {
            con,
            dis,
            invalid,
        }
    }
}
/// Reads the whole of stdin and parses it.
///
/// Input format: `n`, then `n` vertex weights, then `n - 1` edges given as
/// pairs of (1-based) endpoints.
/// Returns `(weights, edges)`. Tokens that fail to parse as `usize` are
/// skipped, and weights are narrowed to `u32` with `as`.
fn read() -> (Vec<u32>, Vec<(usize, usize)>) {
    use std::io::Read;
    let mut input = String::new();
    std::io::stdin().read_to_string(&mut input).unwrap();
    let mut tokens = input.split_whitespace().flat_map(|tok| tok.parse::<usize>());
    let mut next_token = move || tokens.next().unwrap();
    let n = next_token();
    let mut weights = Vec::with_capacity(n);
    for _ in 0..n {
        weights.push(next_token() as u32);
    }
    let mut edges = Vec::new();
    for _ in 1..n {
        let u = next_token();
        let v = next_token();
        edges.push((u, v));
    }
    (weights, edges)
}
/// Operator set for a tree DP evaluated over a static top tree.
///
/// There are two kinds of aggregate:
/// - `Path`: produced by `vertex`, `add_vertex` and `compress`;
/// - `Point`: produced by `add_edge` and `rake`.
pub trait TreeDP {
    /// Weight stored on each vertex.
    type Vertex: Clone;
    /// Weight stored on each edge.
    type Edge: Clone;
    /// Aggregate of a path-shaped cluster.
    type Path: Clone;
    /// Aggregate of clusters hanging below a single vertex.
    type Point: Clone;
    /// Cluster consisting of a single vertex with no children.
    fn vertex(&self, v: &Self::Vertex) -> Self::Path;
    /// Hangs the path cluster `p` below a vertex through edge `e`.
    fn add_edge(&self, p: &Self::Path, e: &Self::Edge) -> Self::Point;
    /// Merges two sibling point clusters.
    fn rake(&self, a: &Self::Point, b: &Self::Point) -> Self::Point;
    /// Attaches the merged children `p` to the vertex with weight `v`.
    fn add_vertex(&self, p: &Self::Point, v: &Self::Vertex) -> Self::Path;
    /// Joins the upper path `p` and the lower path `c` through edge `e`.
    fn compress(&self, p: &Self::Path, c: &Self::Path, e: &Self::Edge) -> Self::Path;
}
/// Rerooting tree DP on top of a `StaticTopTree`.
///
/// Supports point updates of vertex and edge weights, and evaluating the DP
/// with any vertex as the root.
///
/// Type parameters: `R` is the operator set and `V`/`E` are the vertex/edge
/// weights. `A` is the path aggregate and `B` the point aggregate; the `impl`
/// instantiates them as `R::Path` and `R::Point`.
pub struct RerootingDP<R, V, E, A, B> {
    /// Operator set implementing the transitions.
    op: R,
    /// Current weight of each vertex.
    v: Vec<V>,
    /// Current weight of each edge, in input order.
    e: Vec<E>,
    /// Aggregate of each top-tree node.
    /// Path nodes store a pair: `.0` combines the left child before the right
    /// one, and `.1` combines them in reverse (see the `Compress` arm of
    /// `pull`). Point nodes store a single value.
    sum: Vec<Union<(A, A), B>>,
    /// Top tree built once from the edge list, rooted at `ROOT`.
    stt: StaticTopTree,
}
impl<R> RerootingDP<R, R::Vertex, R::Edge, R::Path, R::Point>
where
    R: TreeDP,
{
    /// Vertex the top tree is built from.
    const ROOT: usize = 0;
    /// Builds the structure.
    ///
    /// `v[i]` is the weight of vertex `i`. `edge[k] = (a, b, w)` joins `a` and
    /// `b` with weight `w`; `k` is the edge id used by `set_edge`.
    ///
    /// # Panics
    /// Panics unless `v.len() == edge.len() + 1`.
    pub fn new(op: R, v: Vec<R::Vertex>, edge: Vec<(usize, usize, R::Edge)>) -> Self {
        assert!(v.len() == edge.len() + 1);
        let mut e = vec![];
        let mut memo = vec![];
        for (a, b, w) in edge {
            e.push(w);
            memo.push((a, b));
        }
        let stt = StaticTopTree::new(memo, Self::ROOT);
        // Placeholder contents; every slot is overwritten by the pull loop below.
        let sum = vec![Union::V((op.vertex(&v[0]), op.vertex(&v[0]))); stt.label.len()];
        let mut res = Self { op, v, e, sum, stt };
        // Top-tree nodes are created children-first, so increasing index
        // order computes every child before its parent.
        for i in 0..res.stt.label.len() {
            res.pull(i);
        }
        res
    }
    /// Sets the weight of vertex `v` and recomputes the affected aggregates.
    pub fn set_vertex(&mut self, v: usize, w: R::Vertex) {
        self.v[v] = w;
        self.update(self.stt.vertex[v]);
    }
    /// Sets the weight of edge `e` and recomputes the affected aggregates.
    pub fn set_edge(&mut self, e: usize, w: R::Edge) {
        self.e[e] = w;
        self.update(self.stt.edge[e]);
    }
    /// Returns the DP value of the whole tree with `root` as the root.
    pub fn find(&self, root: usize) -> R::Path {
        if root == Self::ROOT {
            // The last node is the top of the top tree; its `.0` aggregate is
            // already oriented from ROOT.
            return self.sum.last().unwrap().get_v().0.clone();
        }
        // This is rather messy; could it be written more cleanly?
        // Collect the ancestors of `root`'s node in the top tree, each paired
        // with whether we arrived there from its left child.
        let mut pos = self.stt.vertex[root];
        let mut memo = vec![];
        while let Some(p) = self.stt.node[pos].p.get() {
            let l = self.stt.node[p].l.get().unwrap() == pos;
            pos = p;
            memo.push((p, l));
        }
        // Walking from the top-tree root down to `root`, accumulate:
        // - `up` / `down`: path aggregates of the current heavy path above /
        //   below our position, each with the edge that links them to us;
        // - `point`: raked point aggregates around the current vertex;
        // - `vertex`: weight of the vertex whose light subtree we entered.
        let mut up: Option<(R::Path, R::Edge)> = None;
        let mut down: Option<(R::Path, R::Edge)> = None;
        let mut point: Option<R::Point> = None;
        let mut vertex: Option<R::Vertex> = None;
        for &(pos, left) in memo.iter().rev() {
            if self.stt.label[pos] == STTLabel::Compress {
                let e = &self.e[self.stt.node[pos].e.get().unwrap()];
                if left {
                    // We are in the upper part: the right (lower) part, read in
                    // its forward orientation `.0`, goes in front of the
                    // previously collected `down`.
                    let r = self.stt.node[pos].r.get().unwrap();
                    let r = &self.sum[r].get_v().0;
                    down = Some(down.map_or((r.clone(), e.clone()), |(a, b)| {
                        (self.op.compress(r, &a, &b), e.clone())
                    }));
                } else {
                    // We are in the lower part: the left (upper) part, read in
                    // its reversed orientation `.1`, goes in front of `up`.
                    let l = self.stt.node[pos].l.get().unwrap();
                    let l = &self.sum[l].get_v().1;
                    up = Some(up.map_or((l.clone(), e.clone()), |(a, b)| {
                        (self.op.compress(l, &a, &b), e.clone())
                    }));
                }
            } else if self.stt.label[pos] == STTLabel::AddVertex {
                // Entering the light children of this vertex: the heavy-path
                // parts above and below become points around it.
                vertex = Some(self.v[self.stt.node[pos].e.get().unwrap()].clone());
                let u = up.take().map(|p| self.op.add_edge(&p.0, &p.1));
                let d = down.take().map(|p| self.op.add_edge(&p.0, &p.1));
                point = match (u, d) {
                    (Some(a), Some(b)) => Some(self.op.rake(&a, &b)),
                    (a, b) => a.or(b),
                };
            } else if self.stt.label[pos] == STTLabel::Rake {
                // Rake in the sibling subtree we are not descending into.
                let other = if left {
                    self.stt.node[pos].r
                } else {
                    self.stt.node[pos].l
                }
                .get()
                .unwrap();
                let p = self.sum[other].get_e();
                point = Some(point.map_or(p.clone(), |q| self.op.rake(p, &q)));
            } else if self.stt.label[pos] == STTLabel::AddEdge {
                // Descending into one light child: everything gathered so far,
                // closed off with the parent vertex, becomes the new `up`.
                // It is linked to the child through this edge.
                let e = &self.e[self.stt.node[pos].e.get().unwrap()];
                if point.is_some() {
                    let p = point.take().unwrap();
                    let v = vertex.take().unwrap();
                    up = Some((self.op.add_vertex(&p, &v), e.clone()));
                } else {
                    unreachable!()
                }
            } else {
                unreachable!()
            }
        }
        // Finally close everything around `root` itself. An AddVertex node
        // also owns `root`'s light children (its left child); a plain Vertex
        // node has none.
        let pos = self.stt.vertex[root];
        if self.stt.label[pos] == STTLabel::AddVertex {
            let u = up.map(|p| self.op.add_edge(&p.0, &p.1));
            let d = down.map(|p| self.op.add_edge(&p.0, &p.1));
            let p = match (u, d) {
                (Some(a), Some(b)) => Some(self.op.rake(&a, &b)),
                (a, b) => a.or(b),
            }
            .unwrap();
            let c = self.sum[self.stt.node[pos].l.get().unwrap()].get_e();
            let q = self.op.rake(&p, c);
            self.op.add_vertex(&q, &self.v[root])
        } else {
            let u = up.map(|p| self.op.add_edge(&p.0, &p.1));
            let d = down.map(|p| self.op.add_edge(&p.0, &p.1));
            let p = match (u, d) {
                (Some(a), Some(b)) => Some(self.op.rake(&a, &b)),
                (a, b) => a.or(b),
            }
            .unwrap();
            self.op.add_vertex(&p, &self.v[root])
        }
    }
    /// Recomputes node `v` and then every ancestor up to the top-tree root.
    fn update(&mut self, mut v: usize) {
        self.pull(v);
        while let Some(p) = self.stt.node[v].p.get() {
            v = p;
            self.pull(p);
        }
    }
    /// Recomputes the aggregate of node `v` from its children, according to
    /// its label.
    fn pull(&mut self, v: usize) {
        match self.stt.label[v] {
            STTLabel::Vertex => {
                let u = self.stt.node[v].e.get().unwrap();
                let p = self.op.vertex(&self.v[u]);
                self.sum[v].set_v((p.clone(), p));
            }
            STTLabel::AddEdge => {
                let l = self.stt.node[v].l.get().unwrap();
                let e = self.stt.node[v].e.get().unwrap();
                let path = &self.sum[l].get_v().0;
                let point = self.op.add_edge(path, &self.e[e]);
                self.sum[v].set_e(point);
            }
            STTLabel::Rake => {
                let l = self.stt.node[v].l.get().unwrap();
                let r = self.stt.node[v].r.get().unwrap();
                let point = self.op.rake(self.sum[l].get_e(), self.sum[r].get_e());
                self.sum[v].set_e(point);
            }
            STTLabel::AddVertex => {
                let l = self.stt.node[v].l.get().unwrap();
                let u = self.stt.node[v].e.get().unwrap();
                let path = self.op.add_vertex(self.sum[l].get_e(), &self.v[u]);
                self.sum[v].set_v((path.clone(), path));
            }
            STTLabel::Compress => {
                // Keep both orientations so `find` can read a path from either end.
                let l = self.sum[self.stt.node[v].l.get().unwrap()].get_v();
                let r = self.sum[self.stt.node[v].r.get().unwrap()].get_v();
                let e = self.stt.node[v].e.get().unwrap();
                let lr = self.op.compress(&l.0, &r.0, &self.e[e]);
                let rl = self.op.compress(&r.1, &l.1, &self.e[e]);
                self.sum[v].set_v((lr, rl));
            }
        }
    }
}
/// Tree DP rooted at vertex 0, evaluated over a `StaticTopTree`.
///
/// Supports point updates of vertex and edge weights. `find` returns the DP
/// value of the whole tree.
pub struct FixRootTreeDP<R, V, E, A, B> {
    /// Operator set implementing the transitions.
    op: R,
    /// Current weight of each vertex.
    v: Vec<V>,
    /// Current weight of each edge, in input order.
    e: Vec<E>,
    /// Path (`V`) or point (`E`) aggregate of each top-tree node.
    sum: Vec<Union<A, B>>,
    /// Top tree built once from the edge list.
    stt: StaticTopTree,
}
impl<R> FixRootTreeDP<R, R::Vertex, R::Edge, R::Path, R::Point>
where
    R: TreeDP,
{
    /// Builds the structure rooted at vertex 0.
    ///
    /// `v[i]` is the weight of vertex `i`; `edge[k] = (a, b, w)` joins `a` and
    /// `b` with weight `w` (edge id `k`).
    ///
    /// # Panics
    /// Panics unless `v.len() == edge.len() + 1`.
    pub fn new(op: R, v: Vec<R::Vertex>, edge: Vec<(usize, usize, R::Edge)>) -> Self {
        assert!(v.len() == edge.len() + 1);
        let (endpoints, e): (Vec<(usize, usize)>, Vec<R::Edge>) =
            edge.into_iter().map(|(a, b, w)| ((a, b), w)).unzip();
        let stt = StaticTopTree::new(endpoints, 0);
        let node_count = stt.label.len();
        // Placeholder contents; the pull loop below recomputes every slot.
        let sum = vec![Union::V(op.vertex(&v[0])); node_count];
        let mut res = Self { op, v, e, sum, stt };
        for idx in 0..node_count {
            res.pull(idx);
        }
        res
    }
    /// Replaces the weight of vertex `v` and refreshes the affected nodes.
    pub fn set_vertex(&mut self, v: usize, w: R::Vertex) {
        self.v[v] = w;
        let leaf = self.stt.vertex[v];
        self.update(leaf);
    }
    /// Replaces the weight of edge `e` and refreshes the affected nodes.
    pub fn set_edge(&mut self, e: usize, w: R::Edge) {
        self.e[e] = w;
        let leaf = self.stt.edge[e];
        self.update(leaf);
    }
    /// DP value of the whole tree: the aggregate at the last (top) node.
    pub fn find(&self) -> R::Path {
        let top = self.sum.last().unwrap();
        top.get_v().clone()
    }
    /// Recomputes `v`, then each ancestor in turn.
    fn update(&mut self, v: usize) {
        let mut cursor = Some(v);
        while let Some(node) = cursor {
            self.pull(node);
            cursor = self.stt.node[node].p.get();
        }
    }
    /// Recomputes the aggregate of node `v` from its children.
    fn pull(&mut self, v: usize) {
        let links = &self.stt.node[v];
        let (left, right, payload) = (links.l, links.r, links.e);
        match self.stt.label[v] {
            STTLabel::Vertex => {
                let u = payload.get().unwrap();
                let value = self.op.vertex(&self.v[u]);
                self.sum[v] = Union::V(value);
            }
            STTLabel::AddEdge => {
                let child = left.get().unwrap();
                let k = payload.get().unwrap();
                let point = self.op.add_edge(self.sum[child].get_v(), &self.e[k]);
                self.sum[v] = Union::E(point);
            }
            STTLabel::Rake => {
                let a = left.get().unwrap();
                let b = right.get().unwrap();
                let point = self.op.rake(self.sum[a].get_e(), self.sum[b].get_e());
                self.sum[v] = Union::E(point);
            }
            STTLabel::AddVertex => {
                let child = left.get().unwrap();
                let u = payload.get().unwrap();
                let path = self.op.add_vertex(self.sum[child].get_e(), &self.v[u]);
                self.sum[v] = Union::V(path);
            }
            STTLabel::Compress => {
                let upper = self.sum[left.get().unwrap()].get_v();
                let lower = self.sum[right.get().unwrap()].get_v();
                let k = payload.get().unwrap();
                let path = self.op.compress(upper, lower, &self.e[k]);
                self.sum[v] = Union::V(path);
            }
        }
    }
}
/// Slot that holds either a `V` value or an `E` value.
///
/// The tree-DP structures use it to keep path and point aggregates in one
/// vector.
#[derive(Clone, Debug)]
enum Union<V, E> {
    V(V),
    E(E),
}
impl<V, E> Union<V, E> {
    /// Overwrites the slot with a `V` value.
    fn set_v(&mut self, v: V) {
        *self = Union::V(v);
    }
    /// Overwrites the slot with an `E` value.
    fn set_e(&mut self, e: E) {
        *self = Union::E(e);
    }
    /// Borrows the `V` value; panics (unreachable) if the slot holds an `E`.
    fn get_v(&self) -> &V {
        match self {
            Union::V(inner) => inner,
            Union::E(_) => unreachable!(),
        }
    }
    /// Borrows the `E` value; panics (unreachable) if the slot holds a `V`.
    fn get_e(&self) -> &E {
        match self {
            Union::E(inner) => inner,
            Union::V(_) => unreachable!(),
        }
    }
}
/// Kind of a static-top-tree node; it selects the `TreeDP` operation used to
/// combine the node's children.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum STTLabel {
    /// Leaf: a vertex without light children.
    Vertex,
    /// A path hung below a vertex through one edge.
    AddEdge,
    /// Merge of two sibling point aggregates.
    Rake,
    /// A vertex together with its raked light children.
    AddVertex,
    /// Concatenation of two consecutive parts of a heavy path.
    Compress,
}
/// Links of one static-top-tree node. Each field is a possibly-null index.
///
/// `p` is the parent node and `l`/`r` are the child nodes. `e` is the vertex
/// or edge id carried by the node; it is null for `Rake` nodes.
#[derive(Clone, Debug)]
struct STTNode {
    p: Pointer,
    l: Pointer,
    r: Pointer,
    e: Pointer,
}
impl STTNode {
    /// Creates a node with the given children and payload. The parent link
    /// starts out null and is filled in when the parent is created.
    fn new(l: Pointer, r: Pointer, e: Pointer) -> Self {
        let p = Pointer::null();
        STTNode { p, l, r, e }
    }
}
/// Static top tree: a decomposition of a rooted tree into labelled nodes
/// (see `STTLabel`) built once from an edge list.
pub struct StaticTopTree {
    /// Kind of each node.
    label: Vec<STTLabel>,
    /// Parent/child links and payload index of each node.
    node: Vec<STTNode>,
    /// Height of each node; only nodes with a right child add 1 (see `append_inner`).
    height: Vec<usize>,
    /// `vertex[v]`: the Vertex or AddVertex node carrying vertex `v`.
    vertex: Vec<usize>,
    /// `edge[k]`: the AddEdge or Compress node carrying edge `k`.
    edge: Vec<usize>,
    /// Number of tree vertices.
    size: usize,
}
impl StaticTopTree {
    /// Builds the top tree for the tree with edges `edge`, rooted at `root`.
    ///
    /// Vertices are `0..edge.len() + 1`. Nodes are appended children-first,
    /// so every child index is smaller than its parent's, and the last node
    /// appended is the top of the whole structure.
    /// NOTE(review): assumes `edge` forms a connected tree; otherwise the BFS
    /// below indexes past the end of `topo` and panics.
    pub fn new(edge: Vec<(usize, usize)>, root: usize) -> Self {
        let size = edge.len() + 1;
        // Adjacency lists of (neighbor, edge id).
        let mut graph = vec![vec![]; size];
        for (i, &(a, b)) in edge.iter().enumerate() {
            graph[a].push((b, i));
            graph[b].push((a, i));
        }
        // BFS from the root. Removing the parent from each child's list leaves
        // graph[v] holding only the children of v.
        let mut topo = vec![root];
        let mut parent = vec![(size, size); size];
        // NOTE(review): `inv_edge` (child endpoint of each edge) is filled in
        // but never read afterwards.
        let mut inv_edge = vec![size; size];
        for i in 0..size {
            let v = topo[i];
            for (u, k) in graph[v].clone() {
                graph[u].retain(|p| p.0 != v);
                parent[u] = (v, k);
                inv_edge[k] = u;
                topo.push(u);
            }
        }
        // Subtree sizes. The largest child (the heavy child) is moved to graph[v][0].
        let mut s = vec![1i32; size];
        for &v in topo.iter().rev() {
            let c = &mut graph[v];
            for i in 1..c.len() {
                if s[c[i].0] > s[c[0].0] {
                    c.swap(0, i);
                }
            }
            s[v] += c.iter().map(|e| s[e.0]).sum::<i32>();
        }
        let mut stt = Self {
            label: vec![],
            node: vec![],
            height: vec![],
            vertex: vec![!0; size],
            edge: vec![!0; size - 1],
            size,
        };
        // id[v]: node representing v together with its light children, or,
        // once v's heavy path has been folded, the whole path starting at v.
        let mut id = vec![!0; size];
        for &v in topo.iter().rev() {
            if graph[v].len() <= 1 {
                // No light children: a plain Vertex leaf.
                id[v] = stt.append_inner(!0, !0, v, STTLabel::Vertex);
            } else {
                // Wrap each light child in an AddEdge node and combine them with
                // Rake nodes. Partial results are kept per height (`array`, with
                // `bit` marking occupied heights); equal heights are merged like
                // a binary counter carry.
                let mut array = vec![None; 64];
                let mut bit = 0usize;
                for &(u, e) in graph[v][1..].iter() {
                    let mut k = stt.append_inner(id[u], !0, e, STTLabel::AddEdge);
                    let mut h = stt.height[k];
                    while let Some(x) = array[h].take() {
                        bit ^= 1 << h;
                        k = stt.append_inner(k, x, !0, STTLabel::Rake);
                        h = stt.height[k];
                    }
                    array[h] = Some(k);
                    bit |= 1 << h;
                }
                // Fold the remaining partial results from the lowest height upward.
                let x = bit.trailing_zeros() as usize;
                let mut k = array[x].take().unwrap();
                bit ^= 1 << x;
                while bit > 0 {
                    let x = bit.trailing_zeros() as usize;
                    let u = array[x].take().unwrap();
                    k = stt.append_inner(k, u, !0, STTLabel::Rake);
                    bit ^= 1 << x;
                }
                id[v] = stt.append_inner(k, !0, v, STTLabel::AddVertex);
            }
            // v is the top of a heavy path (the root, or not its parent's heavy
            // child). Walk the path downward and fold it into Compress nodes.
            if v == root || graph[parent[v].0][0].0 != v {
                // Stack entries are (node, edge linking this part to the part
                // above it); `size` marks "no edge" for the first entry.
                let mut stack = vec![(id[v], size)];
                let mut pos = v;
                while let Some(&(u, k)) = graph[pos].get(0) {
                    stack.push((id[u], k));
                    // While the newest entry is at least as tall as an earlier
                    // one, merge adjacent entries. The older pair is merged first
                    // when possible, which presumably keeps the tree shallow.
                    while stack.len() > 1 {
                        let len = stack.len();
                        let (b, a) = (stack[len - 2], stack[len - 1]);
                        if len >= 3 && stt.height[stack[len - 3].0] <= stt.height[a.0] {
                            let c = stack[len - 3];
                            stack.truncate(len - 3);
                            let v = stt.append_inner(c.0, b.0, b.1, STTLabel::Compress);
                            stack.extend([(v, c.1), a].iter().cloned());
                        } else if stt.height[b.0] <= stt.height[a.0] {
                            stack.truncate(len - 2);
                            let v = stt.append_inner(b.0, a.0, a.1, STTLabel::Compress);
                            stack.push((v, b.1));
                        } else {
                            break;
                        }
                    }
                    pos = u;
                }
                // Fold whatever remains, from the newest (lowest) part upward.
                while stack.len() >= 2 {
                    let a = stack.pop().unwrap();
                    let b = stack.pop().unwrap();
                    let v = stt.append_inner(b.0, a.0, a.1, STTLabel::Compress);
                    stack.push((v, b.1));
                }
                id[v] = stack.pop().unwrap().0;
            }
        }
        stt
    }
    /// Appends a node and returns its index.
    ///
    /// `l`/`r` are child node indices, and `e` is the vertex or edge id
    /// carried by the node. Out-of-range values (e.g. `!0`) mean "none".
    /// Also sets the children's parent links. The node's height is the maximum
    /// child height, plus 1 when a right child exists. The node is recorded in
    /// `vertex` or `edge` according to `label`.
    fn append_inner(&mut self, l: usize, r: usize, e: usize, label: STTLabel) -> usize {
        let v = self.node.len();
        let mut h = 0;
        let lp = if let Some(n) = self.node.get_mut(l) {
            n.p.set(v);
            h = std::cmp::max(h, self.height[l]);
            Pointer::new(l)
        } else {
            Pointer::null()
        };
        let rp = if let Some(n) = self.node.get_mut(r) {
            n.p.set(v);
            h = std::cmp::max(h, self.height[r]);
            Pointer::new(r)
        } else {
            Pointer::null()
        };
        if self.node.get(r).is_some() {
            h += 1;
        }
        let ep = if e < self.size {
            if label == STTLabel::Vertex || label == STTLabel::AddVertex {
                self.vertex[e] = v;
            } else if label == STTLabel::AddEdge || label == STTLabel::Compress {
                self.edge[e] = v;
            } else {
                unreachable!();
            }
            Pointer::new(e)
        } else {
            Pointer::null()
        };
        self.label.push(label);
        self.node.push(STTNode::new(lp, rp, ep));
        self.height.push(h);
        v
    }
}
// ---------- begin pointer ----------
/// Compact optional index stored as a `u32`; the all-ones value means null.
#[derive(Clone, Copy)]
pub struct Pointer(u32);
impl Pointer {
    /// Wraps index `v`, narrowing it to `u32` with `as`.
    pub fn new(v: usize) -> Self {
        Pointer(v as u32)
    }
    /// The null pointer.
    pub fn null() -> Self {
        Pointer(u32::MAX)
    }
    /// Returns the index, or `None` for null.
    pub fn get(&self) -> Option<usize> {
        match self.0 {
            u32::MAX => None,
            raw => Some(raw as usize),
        }
    }
    /// Whether this pointer is null.
    pub fn is_null(&self) -> bool {
        self.0 == u32::MAX
    }
    /// Points at `v`, narrowing it to `u32` with `as`.
    pub fn set(&mut self, v: usize) {
        self.0 = v as u32;
    }
}
impl From<usize> for Pointer {
    fn from(x: usize) -> Self {
        Pointer::new(x)
    }
}
impl Default for Pointer {
    /// The default pointer is null.
    fn default() -> Self {
        Pointer::null()
    }
}
impl std::fmt::Debug for Pointer {
    /// Prints the index, or `null`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self.get() {
            Some(idx) => write!(f, "{}", idx),
            None => f.write_str("null"),
        }
    }
}
// ---------- end pointer ----------
/// Small formatting helpers kept as a reusable snippet.
// Not every program calls these, so silence the "never used" warning
// instead of deleting the helper.
#[allow(dead_code)]
mod util {
    /// Joins the `Display` form of each item of an iterator with `sep`.
    pub trait Join {
        fn join(self, sep: &str) -> String;
    }
    impl<T, I> Join for I
    where
        I: Iterator<Item = T>,
        T: std::fmt::Display,
    {
        fn join(self, sep: &str) -> String {
            use std::fmt::Write;
            let mut s = String::new();
            for (i, v) in self.enumerate() {
                if i > 0 {
                    s.push_str(sep);
                }
                // Writing into a String cannot fail.
                write!(s, "{}", v).ok();
            }
            s
        }
    }
}
// ---------- begin modint ----------
/// Computes `r^n mod m` by binary exponentiation.
///
/// Products are formed in `u64`, so any `u32` inputs are safe. The result is
/// always in `[0, m)`; in particular `pow_mod(_, 0, 1) == 0`.
///
/// # Panics
/// Panics if `m == 0`.
pub const fn pow_mod(mut r: u32, mut n: u32, m: u32) -> u32 {
    // `1 % m` rather than `1`, so the n == 0 result is reduced too (0 when m == 1).
    let mut t = 1 % m;
    while n > 0 {
        if n & 1 == 1 {
            t = (t as u64 * r as u64 % m as u64) as u32;
        }
        r = (r as u64 * r as u64 % m as u64) as u32;
        n >>= 1;
    }
    t
}
/// Returns the smallest `g` in `[1, p)` such that `g^((p-1)/q) mod p > 1` for
/// every prime `q` dividing `p - 1`. For a prime `p` this is the least
/// primitive root; for `p == 2` it returns 1.
pub const fn primitive_root(p: u32) -> u32 {
    let mut m = p - 1;
    // Distinct prime factors of p - 1 (a u32 has at most 9 of them).
    let mut f = [1; 30];
    let mut k = 0;
    let mut d = 2;
    // `d <= m / d` is `d * d <= m` without the risk of u32 overflow.
    while d <= m / d {
        if m % d == 0 {
            f[k] = d;
            k += 1;
        }
        while m % d == 0 {
            m /= d;
        }
        d += 1;
    }
    if m > 1 {
        f[k] = m;
        k += 1;
    }
    let mut g = 1;
    while g < p {
        let mut ok = true;
        let mut i = 0;
        while i < k {
            ok &= pow_mod(g, (p - 1) / f[i], p) > 1;
            i += 1;
        }
        if ok {
            break;
        }
        g += 1;
    }
    g
}
/// Deterministic primality test by trial division.
pub const fn is_prime(n: u32) -> bool {
    if n <= 1 {
        return false;
    }
    let mut d = 2;
    // `d * d <= n` overflows u32 once d reaches 65536, which happens for
    // n >= 65535^2; `d <= n / d` is equivalent and cannot overflow.
    while d <= n / d {
        if n % d == 0 {
            return false;
        }
        d += 1;
    }
    true
}
/// Integer modulo the compile-time prime `M`, kept in Montgomery form.
///
/// The stored field is `v * 2^32 mod M` for the represented value `v`;
/// `new` converts in and `get` converts out.
#[derive(Clone, Copy, PartialEq, Eq)]
pub struct ModInt<const M: u32>(u32);
impl<const M: u32> ModInt<{ M }> {
    /// `-M^{-1} mod 2^32`, computed as `(-M)^(2^30 - 1)`. Every odd residue
    /// mod 2^32 satisfies `x^(2^30) = 1`, so this power is the inverse.
    const REM: u32 = {
        let mut t = 1u32;
        let mut s = !M + 1;
        let mut n = !0u32 >> 2;
        while n > 0 {
            if n & 1 == 1 {
                t = t.wrapping_mul(s);
            }
            s = s.wrapping_mul(s);
            n >>= 1;
        }
        t
    };
    /// `2^64 mod M`, i.e. `R^2 mod M` for `R = 2^32`.
    const INI: u64 = ((1u128 << 64) % M as u128) as u64;
    /// Compile-time assertion that `M` is prime; forced by `reduce`.
    const IS_PRIME: () = assert!(is_prime(M));
    /// Smallest primitive root modulo `M`.
    const PRIMITIVE_ROOT: u32 = primitive_root(M);
    /// Largest power of two dividing `M - 1`.
    const ORDER: usize = 1 << (M - 1).trailing_zeros();
    /// Montgomery reduction: returns `x * 2^{-32} mod M` in `[0, M)`.
    ///
    /// NOTE(review): valid for `x < M * 2^32`. That holds for products of two
    /// stored values when `M < 2^31`; confirm for other moduli.
    const fn reduce(x: u64) -> u32 {
        let _ = Self::IS_PRIME;
        // Only the low 32 bits of this product are wanted, so it must wrap.
        // A plain `*` panics with "attempt to multiply with overflow" in
        // debug builds.
        let b = (x as u32).wrapping_mul(Self::REM) as u64;
        let t = x + b * M as u64;
        let mut c = (t >> 32) as u32;
        if c >= M {
            c -= M;
        }
        c
    }
    /// Montgomery product of two stored representations.
    const fn multiply(a: u32, b: u32) -> u32 {
        Self::reduce(a as u64 * b as u64)
    }
    /// Wraps the plain value `v`.
    ///
    /// # Panics
    /// Panics if `v >= M`.
    pub const fn new(v: u32) -> Self {
        assert!(v < M);
        Self(Self::reduce(v as u64 * Self::INI))
    }
    /// `const` multiplication (the `Mul` impl delegates here).
    pub const fn const_mul(&self, rhs: Self) -> Self {
        Self(Self::multiply(self.0, rhs.0))
    }
    /// `self^n` by binary exponentiation.
    pub const fn pow(&self, mut n: u64) -> Self {
        let mut t = Self::new(1);
        let mut r = *self;
        while n > 0 {
            if n & 1 == 1 {
                t = t.const_mul(r);
            }
            r = r.const_mul(r);
            n >>= 1;
        }
        t
    }
    /// Multiplicative inverse via Fermat's little theorem.
    ///
    /// # Panics
    /// Panics if `self` is zero.
    pub const fn inv(&self) -> Self {
        assert!(self.0 != 0);
        self.pow(M as u64 - 2)
    }
    /// The represented value in `[0, M)`.
    pub const fn get(&self) -> u32 {
        Self::reduce(self.0 as u64)
    }
    pub const fn zero() -> Self {
        Self::new(0)
    }
    pub const fn one() -> Self {
        Self::new(1)
    }
}
/// Modular addition of canonical (`< M`) representations.
///
/// NOTE(review): `self.0 + rhs.0` assumes `2 * M` fits in `u32` (e.g. `M < 2^31`).
impl<const M: u32> Add for ModInt<{ M }> {
    type Output = Self;
    fn add(self, rhs: Self) -> Self::Output {
        let mut v = self.0 + rhs.0;
        if v >= M {
            v -= M;
        }
        Self(v)
    }
}
/// Modular subtraction.
impl<const M: u32> Sub for ModInt<{ M }> {
    type Output = Self;
    fn sub(self, rhs: Self) -> Self::Output {
        // A plain `self.0 - rhs.0` panics ("attempt to subtract with overflow")
        // in debug builds whenever self.0 < rhs.0. Wrapping and then adding M
        // back (also wrapping) gives self.0 - rhs.0 + M, which is in [0, M).
        let mut v = self.0.wrapping_sub(rhs.0);
        if self.0 < rhs.0 {
            v = v.wrapping_add(M);
        }
        Self(v)
    }
}
impl<const M: u32> Mul for ModInt<{ M }> {
    type Output = Self;
    fn mul(self, rhs: Self) -> Self::Output {
        self.const_mul(rhs)
    }
}
/// Division via the modular inverse; panics if `rhs` is zero.
impl<const M: u32> Div for ModInt<{ M }> {
    type Output = Self;
    fn div(self, rhs: Self) -> Self::Output {
        self * rhs.inv()
    }
}
impl<const M: u32> AddAssign for ModInt<{ M }> {
    fn add_assign(&mut self, rhs: Self) {
        *self = *self + rhs;
    }
}
impl<const M: u32> SubAssign for ModInt<{ M }> {
    fn sub_assign(&mut self, rhs: Self) {
        *self = *self - rhs;
    }
}
impl<const M: u32> MulAssign for ModInt<{ M }> {
    fn mul_assign(&mut self, rhs: Self) {
        *self = *self * rhs;
    }
}
impl<const M: u32> DivAssign for ModInt<{ M }> {
    fn div_assign(&mut self, rhs: Self) {
        *self = *self / rhs;
    }
}
impl<const M: u32> Neg for ModInt<{ M }> {
    type Output = Self;
    fn neg(self) -> Self::Output {
        if self.0 == 0 {
            self
        } else {
            Self(M - self.0)
        }
    }
}
impl<const M: u32> std::fmt::Display for ModInt<{ M }> {
    fn fmt<'a>(&self, f: &mut std::fmt::Formatter<'a>) -> std::fmt::Result {
        write!(f, "{}", self.get())
    }
}
impl<const M: u32> std::fmt::Debug for ModInt<{ M }> {
    fn fmt<'a>(&self, f: &mut std::fmt::Formatter<'a>) -> std::fmt::Result {
        write!(f, "{}", self.get())
    }
}
/// Parses a decimal `u32` and reduces it modulo `M`.
impl<const M: u32> std::str::FromStr for ModInt<{ M }> {
    type Err = std::num::ParseIntError;
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let val = s.parse::<u32>()?;
        // Reduce first: `new` asserts its argument is < M.
        Ok(ModInt::new(val % M))
    }
}
/// Converts with reduction modulo `M`.
impl<const M: u32> From<usize> for ModInt<{ M }> {
    fn from(val: usize) -> ModInt<{ M }> {
        ModInt::new((val % M as usize) as u32)
    }
}
// ---------- end modint ----------
// ---------- begin precalc ----------
/// Tables of factorials, inverse factorials and modular inverses of
/// `0..=size`, for constant-time combinatorics queries.
pub struct Precalc<const MOD: u32> {
    fact: Vec<ModInt<MOD>>,
    ifact: Vec<ModInt<MOD>>,
    inv: Vec<ModInt<MOD>>,
}
impl<const MOD: u32> Precalc<MOD> {
    /// Builds all tables up to and including `size`.
    ///
    /// # Panics
    /// Panics (inside `ModInt::inv`) if `size!` is divisible by `MOD`.
    pub fn new(size: usize) -> Self {
        let len = size + 1;
        let mut fact = vec![ModInt::one(); len];
        let mut ifact = vec![ModInt::one(); len];
        let mut inv = vec![ModInt::one(); len];
        for k in 2..len {
            fact[k] = fact[k - 1] * ModInt::from(k);
        }
        ifact[size] = fact[size].inv();
        // Walk downward: 1/k = (k-1)! / k!, and 1/(k-1)! = k / k!.
        for k in (2..len).rev() {
            inv[k] = ifact[k] * fact[k - 1];
            ifact[k - 1] = ifact[k] * ModInt::from(k);
        }
        Precalc { fact, ifact, inv }
    }
    /// `n!`
    pub fn fact(&self, n: usize) -> ModInt<MOD> {
        self.fact[n]
    }
    /// `1 / n!`
    pub fn ifact(&self, n: usize) -> ModInt<MOD> {
        self.ifact[n]
    }
    /// `1 / n`; panics for `n == 0`.
    pub fn inv(&self, n: usize) -> ModInt<MOD> {
        assert!(0 < n);
        self.inv[n]
    }
    /// Number of ordered selections `n! / (n - k)!`; zero when `k > n`.
    pub fn perm(&self, n: usize, k: usize) -> ModInt<MOD> {
        match n.checked_sub(k) {
            Some(rest) => self.fact[n] * self.ifact[rest],
            None => ModInt::zero(),
        }
    }
    /// Binomial coefficient `C(n, k)`; zero when `k > n`.
    pub fn binom(&self, n: usize, k: usize) -> ModInt<MOD> {
        match n.checked_sub(k) {
            Some(rest) => self.fact[n] * self.ifact[k] * self.ifact[rest],
            None => ModInt::zero(),
        }
    }
}
// ---------- end precalc ----------
impl<const M: u32> Zero for ModInt<{ M }> {
fn zero() -> Self {
Self::zero()
}
fn is_zero(&self) -> bool {
self.0 == 0
}
}
impl<const M: u32> One for ModInt<{ M }> {
fn one() -> Self {
Self::one()
}
fn is_one(&self) -> bool {
self.get() == 1
}
}
// ---------- begin array op ----------
/// Twiddle-factor tables for the butterflies in `ArrayConvolution`.
///
/// The construction appears to follow AtCoder Library's `sum_e` / `sum_ie`
/// tables (NOTE(review): confirm).
struct NTTPrecalc<const M: u32> {
    /// Forward factors, indexed by the trailing zeros of `!i` for block index `i`.
    sum_e: [ModInt<{ M }>; 30],
    /// Inverse-transform counterparts of `sum_e`.
    sum_ie: [ModInt<{ M }>; 30],
}
impl<const M: u32> NTTPrecalc<{ M }> {
    const fn new() -> Self {
        let cnt2 = (M - 1).trailing_zeros() as usize;
        let root = ModInt::new(ModInt::<{ M }>::PRIMITIVE_ROOT);
        // zeta has multiplicative order 2^cnt2.
        let zeta = root.pow((M - 1) as u64 >> cnt2);
        let mut es = [ModInt::zero(); 30];
        let mut ies = [ModInt::zero(); 30];
        let mut sum_e = [ModInt::zero(); 30];
        let mut sum_ie = [ModInt::zero(); 30];
        let mut e = zeta;
        let mut ie = e.inv();
        // es[j] = zeta^(2^(cnt2 - 2 - j)), of order 2^(j + 2); ies[j] is its inverse.
        let mut i = cnt2;
        while i >= 2 {
            es[i - 2] = e;
            ies[i - 2] = ie;
            e = e.const_mul(e);
            ie = ie.const_mul(ie);
            i -= 1;
        }
        // sum_e[i] = es[i] * (ies[0] * ... * ies[i-1]); sum_ie symmetrically.
        let mut now = ModInt::one();
        let mut inow = ModInt::one();
        let mut i = 0;
        while i < cnt2 - 1 {
            sum_e[i] = es[i].const_mul(now);
            sum_ie[i] = ies[i].const_mul(inow);
            now = ies[i].const_mul(now);
            inow = es[i].const_mul(inow);
            i += 1;
        }
        Self { sum_e, sum_ie }
    }
}
/// Holds the tables as an associated const, so they are computed at compile
/// time once per modulus.
struct NTTPrecalcHelper<const MOD: u32>;
impl<const MOD: u32> NTTPrecalcHelper<MOD> {
    const A: NTTPrecalc<MOD> = NTTPrecalc::new();
}
/// Element-wise sum as a new vector; the shorter operand is zero-padded.
pub trait ArrayAdd {
    type Item;
    fn add(&self, rhs: &[Self::Item]) -> Vec<Self::Item>;
}
impl<T> ArrayAdd for [T]
where
    T: Zero + Copy,
{
    type Item = T;
    fn add(&self, rhs: &[Self::Item]) -> Vec<Self::Item> {
        let width = self.len().max(rhs.len());
        let mut out = self.to_vec();
        out.resize(width, T::zero());
        out.add_assign(rhs);
        out
    }
}
/// In-place element-wise `+=`.
pub trait ArrayAddAssign {
    type Item;
    fn add_assign(&mut self, rhs: &[Self::Item]);
}
/// Slices cannot grow, so `rhs` must not be longer than `self`.
impl<T> ArrayAddAssign for [T]
where
    T: Add<Output = T> + Copy,
{
    type Item = T;
    fn add_assign(&mut self, rhs: &[Self::Item]) {
        assert!(self.len() >= rhs.len());
        for (dst, &src) in self.iter_mut().zip(rhs) {
            *dst = *dst + src;
        }
    }
}
/// Vectors are zero-extended to `rhs.len()` first.
impl<T> ArrayAddAssign for Vec<T>
where
    T: Zero + Add<Output = T> + Copy,
{
    type Item = T;
    fn add_assign(&mut self, rhs: &[Self::Item]) {
        let need = rhs.len();
        if need > self.len() {
            self.resize(need, T::zero());
        }
        self.as_mut_slice().add_assign(rhs);
    }
}
/// Element-wise difference as a new vector; the shorter operand is zero-padded.
pub trait ArraySub {
    type Item;
    fn sub(&self, rhs: &[Self::Item]) -> Vec<Self::Item>;
}
impl<T> ArraySub for [T]
where
    T: Zero + Sub<Output = T> + Copy,
{
    type Item = T;
    fn sub(&self, rhs: &[Self::Item]) -> Vec<Self::Item> {
        let width = self.len().max(rhs.len());
        let mut out = self.to_vec();
        out.resize(width, T::zero());
        out.sub_assign(rhs);
        out
    }
}
/// In-place element-wise `-=`.
pub trait ArraySubAssign {
    type Item;
    fn sub_assign(&mut self, rhs: &[Self::Item]);
}
/// Slices cannot grow, so `rhs` must not be longer than `self`.
impl<T> ArraySubAssign for [T]
where
    T: Sub<Output = T> + Copy,
{
    type Item = T;
    fn sub_assign(&mut self, rhs: &[Self::Item]) {
        assert!(self.len() >= rhs.len());
        for (dst, &src) in self.iter_mut().zip(rhs) {
            *dst = *dst - src;
        }
    }
}
/// Vectors are zero-extended to `rhs.len()` first.
impl<T> ArraySubAssign for Vec<T>
where
    T: Zero + Sub<Output = T> + Copy,
{
    type Item = T;
    fn sub_assign(&mut self, rhs: &[Self::Item]) {
        let need = rhs.len();
        if need > self.len() {
            self.resize(need, T::zero());
        }
        self.as_mut_slice().sub_assign(rhs);
    }
}
/// Element-wise (Hadamard) product of equal-length slices.
pub trait ArrayDot {
    type Item;
    fn dot(&self, rhs: &[Self::Item]) -> Vec<Self::Item>;
}
impl<T> ArrayDot for [T]
where
    T: Mul<Output = T> + Copy,
{
    type Item = T;
    fn dot(&self, rhs: &[Self::Item]) -> Vec<Self::Item> {
        assert!(self.len() == rhs.len());
        self.iter().zip(rhs).map(|(&x, &y)| x * y).collect()
    }
}
/// In-place element-wise product of equal-length slices.
pub trait ArrayDotAssign {
    type Item;
    fn dot_assign(&mut self, rhs: &[Self::Item]);
}
impl<T> ArrayDotAssign for [T]
where
    T: MulAssign + Copy,
{
    type Item = T;
    fn dot_assign(&mut self, rhs: &[Self::Item]) {
        assert!(self.len() == rhs.len());
        for (dst, &src) in self.iter_mut().zip(rhs) {
            *dst *= src;
        }
    }
}
/// Schoolbook polynomial product; empty if either operand is empty.
pub trait ArrayMul {
    type Item;
    fn mul(&self, rhs: &[Self::Item]) -> Vec<Self::Item>;
}
impl<T> ArrayMul for [T]
where
    T: Zero + One + Copy,
{
    type Item = T;
    fn mul(&self, rhs: &[Self::Item]) -> Vec<Self::Item> {
        if self.is_empty() || rhs.is_empty() {
            return vec![];
        }
        let mut out = vec![T::zero(); self.len() + rhs.len() - 1];
        for i in 0..self.len() {
            let a = self[i];
            for j in 0..rhs.len() {
                out[i + j] = out[i + j] + a * rhs[j];
            }
        }
        out
    }
}
// Calling `transform` with len = 1 gives an ordinary NTT.
/// Number-theoretic transforms over `ModInt<M>` slices, plus convolution.
///
/// `transform(len)` / `inverse_transform(len)` view the slice as `n / len`
/// blocks of `len` coefficients and run the butterfly network across the
/// blocks only.
pub trait ArrayConvolution {
    type Item;
    /// Forward transform in place; `self.len()` must be `len * 2^k`.
    fn transform(&mut self, len: usize);
    /// Inverse of `transform`, including the `1 / 2^k` scaling.
    fn inverse_transform(&mut self, len: usize);
    /// Linear convolution (polynomial product) of `self` and `rhs`.
    fn convolution(&self, rhs: &[Self::Item]) -> Vec<Self::Item>;
}
impl<const M: u32> ArrayConvolution for [ModInt<{ M }>] {
    type Item = ModInt<{ M }>;
    fn transform(&mut self, len: usize) {
        let f = self;
        let n = f.len();
        let k = (n / len).trailing_zeros() as usize;
        assert!(len << k == n);
        // The block count n / len = 2^k must not exceed ORDER, the largest
        // power of two dividing M - 1. Comparing the exponent `k` itself
        // with ORDER (a power of two) would make this check vacuous.
        assert!(n / len <= ModInt::<{ M }>::ORDER);
        let pre = &NTTPrecalcHelper::<{ M }>::A;
        for ph in 1..=k {
            let p = len << (k - ph);
            let mut now = ModInt::one();
            for (i, f) in f.chunks_exact_mut(2 * p).enumerate() {
                let (x, y) = f.split_at_mut(p);
                for (x, y) in x.iter_mut().zip(y.iter_mut()) {
                    let l = *x;
                    let r = *y * now;
                    *x = l + r;
                    *y = l - r;
                }
                now *= pre.sum_e[(!i).trailing_zeros() as usize];
            }
        }
    }
    fn inverse_transform(&mut self, len: usize) {
        let f = self;
        let n = f.len();
        let k = (n / len).trailing_zeros() as usize;
        assert!(len << k == n);
        // Same bound as in `transform`: the block count must not exceed ORDER.
        assert!(n / len <= ModInt::<{ M }>::ORDER);
        let pre = &NTTPrecalcHelper::<{ M }>::A;
        for ph in (1..=k).rev() {
            let p = len << (k - ph);
            let mut inow = ModInt::one();
            for (i, f) in f.chunks_exact_mut(2 * p).enumerate() {
                let (x, y) = f.split_at_mut(p);
                for (x, y) in x.iter_mut().zip(y.iter_mut()) {
                    let l = *x;
                    let r = *y;
                    *x = l + r;
                    *y = (l - r) * inow;
                }
                inow *= pre.sum_ie[(!i).trailing_zeros() as usize];
            }
        }
        let ik = ModInt::new(2).inv().pow(k as u64);
        for f in f.iter_mut() {
            *f *= ik;
        }
    }
    fn convolution(&self, rhs: &[Self::Item]) -> Vec<Self::Item> {
        // Small (or empty) operands: schoolbook multiplication.
        if self.len().min(rhs.len()) <= 32 {
            return self.mul(rhs);
        }
        // Split the result length into 2^k blocks of at most PARAM coefficients.
        const PARAM: usize = 10;
        let size = self.len() + rhs.len() - 1;
        let mut k = 0;
        while (size + (1 << k) - 1) >> k > PARAM {
            k += 1;
        }
        let len = (size + (1 << k) - 1) >> k;
        let mut f = vec![ModInt::zero(); len << k];
        let mut g = vec![ModInt::zero(); len << k];
        f[..self.len()].copy_from_slice(self);
        g[..rhs.len()].copy_from_slice(rhs);
        f.transform(len);
        g.transform(len);
        let mut buf = [ModInt::zero(); 2 * PARAM - 1];
        let buf = &mut buf[..(2 * len - 1)];
        let pre = &NTTPrecalcHelper::<{ M }>::A;
        let mut now = ModInt::one();
        // Multiply corresponding blocks as polynomials and fold the high half
        // back using x^len = r, where r alternates between `now` and `-now`.
        for (i, (f, g)) in f
            .chunks_exact_mut(2 * len)
            .zip(g.chunks_exact(2 * len))
            .enumerate()
        {
            let mut r = now;
            for (f, g) in f.chunks_exact_mut(len).zip(g.chunks_exact(len)) {
                buf.fill(ModInt::zero());
                for (i, f) in f.iter().enumerate() {
                    for (buf, g) in buf[i..].iter_mut().zip(g.iter()) {
                        *buf = *buf + *f * *g;
                    }
                }
                f.copy_from_slice(&buf[..len]);
                for (f, buf) in f.iter_mut().zip(buf[len..].iter()) {
                    *f = *f + r * *buf;
                }
                r = -r;
            }
            now *= pre.sum_e[(!i).trailing_zeros() as usize];
        }
        f.inverse_transform(len);
        f.truncate(self.len() + rhs.len() - 1);
        f
    }
}
// ---------- end array op ----------
// ---------- begin trait ----------
use std::ops::*;
/// Types with an additive identity.
pub trait Zero: Sized + Add<Self, Output = Self> {
    fn zero() -> Self;
    fn is_zero(&self) -> bool;
}
/// Types with a multiplicative identity.
pub trait One: Sized + Mul<Self, Output = Self> {
    fn one() -> Self;
    fn is_one(&self) -> bool;
}
/// Marker traits for the usual algebraic structures. They have blanket
/// impls, so any type with the required operators qualifies automatically.
pub trait Group: Zero + Sub<Output = Self> + Neg<Output = Self> {}
pub trait SemiRing: Zero + One {}
pub trait Ring: SemiRing + Group {}
pub trait Field: Ring + Div<Output = Self> {}
impl<T> Group for T where T: Zero + Sub<Output = Self> + Neg<Output = Self> {}
impl<T> SemiRing for T where T: Zero + One {}
impl<T> Ring for T where T: SemiRing + Group {}
impl<T> Field for T where T: Ring + Div<Output = Self> {}
/// Shorthand for `T::zero()`.
pub fn zero<T: Zero>() -> T {
    T::zero()
}
/// Shorthand for `T::one()`.
pub fn one<T: One>() -> T {
    T::one()
}
/// `r^n` by square-and-multiply; the accumulator is always the left factor.
pub fn pow<T: One + Clone>(mut r: T, mut n: usize) -> T {
    let mut acc: T = one();
    while n != 0 {
        if n % 2 == 1 {
            acc = acc * r.clone();
        }
        r = r.clone() * r;
        n /= 2;
    }
    acc
}
/// Geometric sum `1 + r + r^2 + ... + r^(n-1)` (zero for `n == 0`).
///
/// Uses O(log n) multiplications and no division. `block` holds
/// `1 + r + ... + r^(2^j - 1)` for the current power `r = r0^(2^j)`.
pub fn pow_sum<T: SemiRing + Clone>(mut r: T, mut n: usize) -> T {
    let mut acc = T::zero();
    let mut block = T::one();
    while n != 0 {
        if n % 2 == 1 {
            acc = acc * r.clone() + block.clone();
        }
        block = block * (T::one() + r.clone());
        r = r.clone() * r;
        n /= 2;
    }
    acc
}
// ---------- end trait ----------
/// First-order dual number `a + b·ε` with `ε² = 0`.
///
/// Field `.0` is the value and field `.1` the coefficient of `ε`.
#[derive(Clone, Copy, Default, Debug)]
pub struct Dual<T>(T, T);
impl<T> Dual<T> {
    /// Builds `a + b·ε`.
    pub fn new(a: T, b: T) -> Self {
        Dual(a, b)
    }
}
impl<T> Zero for Dual<T>
where
    T: Zero,
{
    fn zero() -> Self {
        Dual(T::zero(), T::zero())
    }
    fn is_zero(&self) -> bool {
        let Dual(value, eps) = self;
        value.is_zero() && eps.is_zero()
    }
}
impl<T> One for Dual<T>
where
    T: One + Zero + Clone,
{
    fn one() -> Self {
        Dual(T::one(), T::zero())
    }
    fn is_one(&self) -> bool {
        let Dual(value, eps) = self;
        value.is_one() && eps.is_zero()
    }
}
/// Component-wise addition.
impl<T> Add for Dual<T>
where
    T: Add<Output = T>,
{
    type Output = Self;
    fn add(self, rhs: Self) -> Self {
        let Dual(a, b) = self;
        let Dual(c, d) = rhs;
        Dual(a + c, b + d)
    }
}
impl<T> AddAssign for Dual<T>
where
    T: Add<Output = T> + Clone,
{
    fn add_assign(&mut self, rhs: Self) {
        let lhs = self.clone();
        *self = lhs + rhs;
    }
}
/// Component-wise subtraction.
impl<T> Sub for Dual<T>
where
    T: Sub<Output = T>,
{
    type Output = Self;
    fn sub(self, rhs: Self) -> Self {
        let Dual(a, b) = self;
        let Dual(c, d) = rhs;
        Dual(a - c, b - d)
    }
}
impl<T> SubAssign for Dual<T>
where
    T: Sub<Output = T> + Clone,
{
    fn sub_assign(&mut self, rhs: Self) {
        let lhs = self.clone();
        *self = lhs - rhs;
    }
}
/// `(a + bε)(c + dε) = ac + (ad + bc)ε`.
impl<T> Mul for Dual<T>
where
    T: Clone + Add<Output = T> + Mul<Output = T>,
{
    type Output = Self;
    fn mul(self, rhs: Self) -> Self {
        let Dual(a, b) = self;
        let Dual(c, d) = rhs;
        let value = a.clone() * c.clone();
        Dual(value, a * d + b * c)
    }
}
impl<T> MulAssign for Dual<T>
where
    T: Clone + Add<Output = T> + Mul<Output = T>,
{
    fn mul_assign(&mut self, rhs: Self) {
        let lhs = self.clone();
        *self = lhs * rhs;
    }
}
akakimidori