"""Read a rooted tree from stdin and print the product of non-root depths mod 998244353.

Input format (as read by the entry point below): N on the first line, then
N-1 lines each holding the 1-indexed parent of vertices 2..N.  Vertex 1
(index 0) is the root and has depth 0.
"""
import sys

input = lambda: sys.stdin.readline().rstrip()

mod = 998244353


def write(i, depth, parents=None):
    """Fill in depth[i] and the depths of the unresolved ancestors on its path.

    Walks up the parent chain from ``i`` until it reaches a vertex whose depth
    is already known (``depth[v] != -1``), then assigns depths back down the
    recorded path.  The walk is iterative, so paths of any length work; a
    recursive version fails with RecursionError once a path exceeds Python's
    recursion limit (1000 frames by default).

    Args:
        i: index of the vertex to resolve.
        depth: per-vertex depth list where -1 means "not computed yet";
            modified in place.  Some vertex on every path must already be set
            (the entry point sets ``depth[0] = 0``).
        parents: parent index of each vertex.  Defaults to the module-level
            ``par`` list built by the script, so the original two-argument
            call ``write(i, depth)`` still works.

    Raises:
        ValueError: if the parent chain from ``i`` never reaches a vertex of
            known depth (it loops), which would otherwise run forever.
    """
    if parents is None:
        parents = par  # module-level parent list set by the script entry point
    path = []
    v = i
    while depth[v] == -1:
        path.append(v)
        # A valid chain visits each unresolved vertex at most once.
        if len(path) > len(depth):
            raise ValueError(
                "parent chain from vertex %d does not reach a vertex of known depth" % i
            )
        v = parents[v]
    d = depth[v]
    # Assign top-down: each vertex is exactly one deeper than its parent.
    for u in reversed(path):
        d += 1
        depth[u] = d


def solve(parents, modulus=mod):
    """Return the product of depth[v] over all non-root vertices v, mod ``modulus``.

    ``parents[v]`` is the 0-indexed parent of vertex v; ``parents[0]`` (the
    root's entry) is never read.  The root has depth 0, so it is left out of
    the product: multiplying it in forces the result to 0 for every input.
    NOTE(review): the per-vertex factor is taken to be depth[v] for v != 0;
    confirm against the problem statement (e.g. that it is not depth + 1).
    """
    n = len(parents)
    depth = [-1] * n
    depth[0] = 0
    for v in range(1, n):
        write(v, depth, parents)
    ans = 1
    for d in depth[1:]:
        ans = ans * d % modulus
    return ans


if __name__ == "__main__":
    N = int(input())
    par = [0] + [int(input()) - 1 for _ in range(N - 1)]
    print(solve(par))