import java.io.PrintWriter
import scala.collection.mutable.*
import scala.io.StdIn.*
import scala.util.chaining.*
import scala.math.*
import scala.reflect.ClassTag
import scala.util.*
import scala.annotation.tailrec
import scala.collection.mutable
import scala.language.strictEquality

/** Product, modulo `mod`, of the depth of every node of a rooted tree.
  *
  * Node 0 is the root. For `i >= 1`, the parent of node `i` is `parents(i - 1)`
  * (0-based), so the tree has `parents.length + 1` nodes. Depths are found by BFS
  * from the root, so parents do not need to precede their children in the array.
  * Each node contributes `max(depth, 1)`. The root (depth 0) therefore contributes
  * 1, and so does any node that BFS never reaches from node 0, because its depth
  * stays 0.
  *
  * @param parents 0-based parent index of nodes 1..n-1; every entry must be in
  *                `[0, parents.length]`, otherwise indexing throws
  * @param mod     positive modulus (default 998244353); intermediate products are
  *                `acc * depth` in `Long`, so `mod * n` must fit in a `Long`
  * @return the product of `max(depth, 1)` over all nodes, reduced modulo `mod`
  */
def depthProduct(parents: Array[Int], mod: Long = 998244353L): Long =
  require(mod > 0L, "mod must be positive")
  val n = parents.length + 1
  val children = Array.fill(n)(ArrayBuffer.empty[Int])
  for i <- parents.indices do children(parents(i)) += i + 1
  val depth = new Array[Int](n)
  val queue = mutable.ArrayDeque(0)
  while queue.nonEmpty do
    val node = queue.removeHead()
    for child <- children(node) do
      depth(child) = depth(node) + 1
      queue.append(child)
  // Seed with `1L % mod` so that mod == 1 still yields 0; for mod > 1 this is 1.
  depth.foldLeft(1L % mod)((acc, d) => acc * max(d, 1) % mod)

/** Reads `n` and then `n - 1` one-based parent indices from stdin, and prints
  * [[depthProduct]] of the tree they describe.
  *
  * Tokens may be separated by any whitespace, including newlines, so both
  * one-value-per-line input and space-separated input are accepted.
  */
@main def main =
  val tokens = scala.io.Source.stdin
    .getLines()
    .flatMap(_.split("\\s+").iterator.filter(_.nonEmpty))
    .map(_.toInt)
  val n = tokens.next()
  // Input parents are 1-based; convert to 0-based node indices.
  val parents = Array.fill(n - 1)(tokens.next() - 1)
  println(depthProduct(parents))