import sys

MOD = 998244353


def _suffix_sums(poly, mod):
    """Return the list whose i-th entry is sum(poly[i:]) modulo ``mod``."""
    out = [0] * len(poly)
    running = 0
    for i in range(len(poly) - 1, -1, -1):
        running = (running + poly[i]) % mod
        out[i] = running
    return out


def _multiply(f, g, mod):
    """Return the product of coefficient lists ``f`` and ``g`` modulo ``mod``."""
    out = [0] * (len(f) + len(g) - 1)
    for i, fi in enumerate(f):
        if fi:  # zero coefficients contribute nothing
            for j, gj in enumerate(g):
                out[i + j] += fi * gj
    # Python ints do not overflow, so reduce once at the end instead of
    # after every multiply-add.
    return [c % mod for c in out]


def solve(n, parents, mod=MOD):
    """Run the tree polynomial DP and return the final coefficient sum.

    Nodes are numbered 1..n with node 1 as the root; ``parents[k]`` is the
    1-based parent of node ``k + 2`` (i.e. the values of the input's second
    line, in order).

    Every node starts with the polynomial ``x`` (coefficients ``[0, 1]``).
    After all of a node's children have been merged into it, the node's
    coefficient list is replaced by its suffix sums and multiplied into its
    parent's polynomial. The answer is the sum of the root's suffix-summed
    coefficients, reduced modulo ``mod``.

    Raises:
        ValueError: if ``len(parents) != n - 1``.
    """
    if len(parents) != n - 1:
        raise ValueError(f"expected {n - 1} parent values, got {len(parents)}")

    children = [[] for _ in range(n)]
    for child, parent in enumerate(parents, start=1):
        children[parent - 1].append(child)

    # Iterative DFS from the root yields a pre-order in which every parent
    # precedes its descendants, so walking it backwards finishes each
    # subtree before its parent is processed. Each node appears in exactly
    # one child list, so no visited check is needed.
    order = []
    stack = [0]
    while stack:
        node = stack.pop()
        order.append(node)
        stack.extend(children[node])

    parent_of = [-1] + [p - 1 for p in parents]
    dp = [[0, 1] for _ in range(n)]
    for node in reversed(order[1:]):  # order[0] is the root
        poly = _suffix_sums(dp[node], mod)
        dp[node] = None  # merged into the parent; release it
        up = parent_of[node]
        dp[up] = _multiply(dp[up], poly, mod)

    return sum(_suffix_sums(dp[0], mod)) % mod


def main():
    """Read ``n`` and the parents of nodes 2..n from stdin; print the answer."""
    tokens = sys.stdin.read().split()
    n = int(tokens[0])
    # Token-based reading tolerates a missing or empty parent line when n == 1.
    parents = [int(tok) for tok in tokens[1:n]]
    print(solve(n, parents))


if __name__ == "__main__":
    main()