import sys

MOD = 998244353


def convolve_trunc(a, b, limit):
    """Return the product of coefficient lists *a* and *b* modulo MOD, truncated at degree *limit*.

    Coefficients are ordered lowest degree first. The result has length
    ``min(limit, len(a) + len(b) - 2) + 1``; every entry is reduced mod MOD.
    """
    # Multiplication is commutative: put the shorter list in the outer loop
    # to cut Python-level loop overhead.
    if len(a) > len(b):
        a, b = b, a
    res = [0] * (min(limit, len(a) + len(b) - 2) + 1)
    last_b = len(b) - 1
    for i, ai in enumerate(a):
        if i > limit:
            break  # every remaining term would land above the truncation degree
        if ai == 0:
            continue
        # Cap j so that i + j never exceeds limit.
        for j in range(min(last_b, limit - i) + 1):
            bj = b[j]
            if bj:
                res[i + j] = (res[i + j] + ai * bj) % MOD
    return res


def _postorder(children, root=1):
    """Return the nodes reachable from *root*, ordered so every child precedes its parent."""
    order = []
    stack = [root]
    while stack:
        v = stack.pop()
        order.append(v)
        stack.extend(children[v])
    order.reverse()  # reversed preorder: children always come before their parent
    return order


def solve(n, parents):
    """Return the DP answer for the rooted tree on nodes 1..n, modulo MOD.

    *parents* holds the parents of nodes 2..n in order: ``parents[i - 2]`` is the
    parent of node ``i``. Node 1 is the root. Only the first ``n - 1`` entries
    are used.

    For each node v with subtree size sz[v], the DP table is
        dp_v[s] = sum_{t <= s} C_v[t]   for s in 0..sz[v],
    where C_v is the product of the children's tables, viewed as polynomials.
    The answer is sum(dp_1) mod MOD. By this recurrence, the answer counts
    assignments of integers x_v with 0 <= x_v <= sz[v] and
    x_v >= (sum of x_c over the children c of v), taken modulo MOD.

    Raises:
        ValueError: if n < 1 or fewer than n - 1 parents are given.
    """
    if n < 1:
        raise ValueError("n must be at least 1")
    if len(parents) < n - 1:
        raise ValueError(f"expected {n - 1} parents, got {len(parents)}")

    children = [[] for _ in range(n + 1)]
    for child, parent in enumerate(parents[: n - 1], start=2):
        children[parent].append(child)

    # Explicit post-order, so the computation no longer relies on parent < child.
    order = _postorder(children)

    sz = [1] * (n + 1)
    for v in order:
        sz[v] = 1 + sum(sz[c] for c in children[v])

    # Multiply small children first to keep intermediate polynomials short.
    for kids in children:
        kids.sort(key=lambda x: sz[x])

    dp = [None] * (n + 1)
    for v in order:
        limit = sz[v]
        conv = [1]  # generating polynomial of the sum over v's children
        for c in children[v]:
            conv = convolve_trunc(conv, dp[c], limit)
            dp[c] = None  # each child's table is consumed exactly once; free it early
        # dp_v[s] = prefix sum of conv up to s (conv may be shorter than limit + 1).
        table = [0] * (limit + 1)
        run = 0
        for s in range(limit + 1):
            if s < len(conv):
                run = (run + conv[s]) % MOD
            table[s] = run
        dp[v] = table

    return sum(dp[1]) % MOD


def main():
    """Read N and the N - 1 parents from stdin, then print solve(N, parents).

    The input is split on any whitespace, so the parent list may span lines.
    """
    tokens = sys.stdin.read().split()
    n = int(tokens[0])
    parents = [int(t) for t in tokens[1:n]]
    print(solve(n, parents))


if __name__ == "__main__":
    main()