import sys

MOD = 998244353


def solve(n, par, mod=MOD):
    """Run the subtree-merge DP on a rooted tree and return the root total.

    The tree has vertices ``0..n-1`` rooted at vertex 0; ``par[v]`` is the
    0-indexed parent of vertex ``v`` for ``1 <= v < n`` (``par[0]`` is
    ignored). Parents are NOT required to have smaller indices than their
    children.

    For every vertex the DP builds a table ``f`` of length ``size + 1``
    (``size`` = number of vertices in its subtree):

    * start from ``f = [0, 1]`` (the vertex on its own),
    * fold in each child's finished table ``g`` by polynomial
      multiplication: ``f'[l + k] += f[l] * g[k]`` for ``l >= 1, k >= 0``,
    * finally replace ``f`` by its suffix sums, so ``f[k] = sum(f[k:])``.

    A leaf therefore ends up with ``[1, 1]``.
    NOTE(review): the combinatorial meaning of the table (which problem this
    answers) is not visible from the code; only the recurrence is documented.

    Args:
        n: number of vertices (>= 1).
        par: parent array as described above.
        mod: modulus applied to every stored value.

    Returns:
        ``sum(root_table) % mod``.
    """
    children = [[] for _ in range(n)]
    for v in range(1, n):
        children[par[v]].append(v)

    # Preorder from the root; walking it backwards visits every child before
    # its parent no matter how the vertices are numbered.
    order = []
    stack = [0]
    while stack:
        u = stack.pop()
        order.append(u)
        stack.extend(children[u])

    tables = [None] * n
    for u in reversed(order):
        cur = [0, 1]  # index 0 stays 0 until the suffix-sum pass below
        for c in children[u]:
            sub = tables[c]
            tables[c] = None  # child table is no longer needed; free it early
            merged = [0] * (len(cur) + len(sub) - 1)
            for l in range(1, len(cur)):
                a = cur[l]
                for k, b in enumerate(sub):
                    merged[l + k] = (merged[l + k] + a * b) % mod
            cur = merged
        # Turn the table into suffix sums: cur[k] becomes sum(cur[k:]).
        for k in range(len(cur) - 1, 0, -1):
            cur[k - 1] = (cur[k - 1] + cur[k]) % mod
        tables[u] = cur
    return sum(tables[0]) % mod


def main():
    """Read ``N`` and 1-indexed parents ``p_2..p_N`` from stdin; print the answer."""
    data = sys.stdin.buffer.read().split()
    n = int(data[0])
    # Convert to 0-indexed parents and pad index 0 (the root) with a dummy.
    # Reading all tokens at once also tolerates a missing parent line when n == 1.
    par = [0] + [int(x) - 1 for x in data[1:n]]
    print(solve(n, par))


if __name__ == "__main__":
    main()