import sys

MOD = 998244353


def solve(n, parents, mod=MOD):
    """Evaluate the tree recurrence for a rooted tree given as a parent list.

    Node 0 is the root. ``parents[i - 1]`` is the 1-based parent of node ``i``
    for ``i`` in ``1..n-1``. A node's value is the product of its children's
    values (1 for a leaf), plus 1 when the node's depth is at least 2. The
    root is at depth 0. All arithmetic is modulo ``mod``.

    Args:
        n: Number of nodes.
        parents: Sequence of ``n - 1`` 1-based parent indices for nodes 1..n-1.
        mod: Modulus applied to every intermediate value.

    Returns:
        The root's value modulo ``mod``.
    """
    children = [[] for _ in range(n)]
    for child, parent in enumerate(parents, start=1):
        children[parent - 1].append(child)

    # Iterative BFS instead of recursion: raising the recursion limit does not
    # enlarge the C stack, so a deep tree could crash a recursive DFS.
    # Appending to `order` while iterating over it walks the BFS frontier;
    # every node appears after its parent.
    depth = [0] * n
    order = [0]
    for u in order:
        for v in children[u]:
            depth[v] = depth[u] + 1
            order.append(v)

    # Visit nodes in reverse BFS order so each child is finished before its
    # parent. The product is commutative, so child order does not matter.
    value = [1] * n
    for u in reversed(order):
        acc = 1
        for v in children[u]:
            acc = acc * value[v] % mod
        if depth[u] >= 2:
            acc = (acc + 1) % mod
        value[u] = acc
    return value[0]


def main():
    """Read ``n`` and then ``n - 1`` parent indices from stdin; print the answer.

    Reads all whitespace-separated tokens at once, so the parents may come
    one per line or on a single line.
    """
    data = sys.stdin.buffer.read().split()
    n = int(data[0])
    parents = [int(tok) for tok in data[1:n]]
    print(solve(n, parents))


if __name__ == "__main__":
    main()