n=int(input()) e=[[] for i in range(n)] for i in range(n-1): a,b=map(int,input().split()) a-=1 b-=1 e[a]+=[b] e[b]+=[a] M=998244353 v=[0]*n p=[[1,0] for i in range(n)] q=[0] while len(q)>0: s=q[-1] if v[s]==0: v[s]=1 for t in e[s]: if v[t]==0: q+=[t] else: for t in e[s]: if v[t]==0: p[s][1]=p[s][1]*(p[t][0]+p[t][1])+p[s][0]*p[t][1] p[s][1]%=M p[s][0]=p[s][0]*(p[t][0]+p[t][1]) p[s][0]%=M p[s][1]+=p[s][0] p[s][1]%=M v[s]=0 q.pop() print(p[0][1])