"""Product of node depths in a rooted tree.

Input: an integer N, followed by the parents p_2, ..., p_N of nodes 2..N
(node 1 is the root). Tokens may be separated by any whitespace, so both
the one-value-per-line and the single-line space-separated layouts work.

Output: the product of depth(v) over v = 2..N, modulo MOD.
"""

import sys
from collections import deque

MOD = 998244353


def compute_depths(n, parent):
    """Return ``depth`` where ``depth[v]`` is the BFS distance from node 1 to v.

    ``parent`` is indexed by node: ``parent[v]`` is the parent of v for
    ``2 <= v <= n``; entries at indices 0 and 1 are ignored. Nodes that are
    not reachable from node 1 (a parent outside ``1..n``, or a cycle) keep
    depth 0. Index 0 of the returned list is unused.
    """
    # Build child lists once so the BFS is O(n); rescanning every node for
    # each dequeued vertex would be O(n^2).
    children = [[] for _ in range(n + 1)]
    for v in range(2, n + 1):
        p = parent[v]
        if 1 <= p <= n:  # an out-of-range parent can never be reached
            children[p].append(v)

    depth = [0] * (n + 1)
    visited = [False] * (n + 1)
    visited[1] = True
    queue = deque([1])
    while queue:
        u = queue.popleft()
        for v in children[u]:
            if not visited[v]:
                visited[v] = True
                depth[v] = depth[u] + 1
                queue.append(v)
    return depth


def product_of_depths(n, parent, mod=MOD):
    """Return the product of depth(v) for v in 2..n, reduced modulo ``mod``.

    For n == 1 the product is empty and the result is 1. Any node that is
    unreachable from node 1 has depth 0, which makes the whole product 0.
    """
    depth = compute_depths(n, parent)
    result = 1
    for v in range(2, n + 1):
        result = result * depth[v] % mod
    return result


def solve(tokens):
    """Compute the answer from whitespace-split input ``tokens``.

    ``tokens[0]`` is N and ``tokens[1:N]`` are p_2..p_N. Tokens may be
    ``str`` or ``bytes``, since ``int`` accepts both.

    Raises:
        ValueError: if fewer than N tokens are supplied, or a token is not an
            integer.
    """
    n = int(tokens[0])
    if len(tokens) < n:
        raise ValueError("expected N-1 parent values after N")
    # Pad indices 0 and 1 so that parent[v] is the parent of node v.
    parent = [0, 0] + [int(t) for t in tokens[1:n]]
    return product_of_depths(n, parent)


def main():
    """Read the tree from standard input and print the answer."""
    print(solve(sys.stdin.buffer.read().split()))


if __name__ == '__main__':
    main()