import sys

input = sys.stdin.readline


def _main():
    """Read a rooted tree from stdin and print 2**k modulo 998244353.

    Input: the first line is N; then N-1 lines, each holding one integer,
    give the parent of nodes 2, 3, ..., N in that order.
    k is the number of nodes 2..N that are reachable from node 1 and whose
    depth (edges from node 1) is greater than 1, i.e. nodes that are neither
    the root nor a direct child of the root.
    """
    MOD = 998244353
    n = int(input())

    children = [[] for _ in range(n + 1)]
    for node in range(2, n + 1):
        parent = int(input())
        children[parent].append(node)

    # Iterative DFS from node 1 assigning depths; `seen` prevents a node
    # from being expanded twice.
    seen = [False] * (n + 1)
    seen[1] = True
    depth = [0] * (n + 1)
    stack = [1]
    while stack:
        cur = stack.pop()
        for child in children[cur]:
            if not seen[child]:
                seen[child] = True
                depth[child] = depth[cur] + 1
                stack.append(child)

    deep = sum(1 for node in range(2, n + 1) if depth[node] > 1)
    # Three-argument pow yields the same value as doubling `deep` times mod MOD.
    print(pow(2, deep, MOD))


_main()