"""Two-state tree DP evaluated bottom-up, modulo 998244353.

Input (stdin): an integer N, followed by N - 1 lines each holding two
1-based vertex numbers that describe an edge of the tree.
Output (stdout): state 1 of vertex 1 (the root) modulo 998244353.

NOTE(review): the combinatorial meaning of the two DP states is not stated
in the original script; the comments below describe the recurrence only.
"""
import sys

MOD = 998244353


def solve(n: int, edges) -> int:
    """Return state 1 of the DP at vertex 0 for the given tree.

    For a vertex whose children carry the state pairs (a_i, b_i):

        state0 = prod_i (a_i + b_i)
        state1 = state0 + sum_i b_i * prod_{j != i} (a_j + b_j)

    A leaf therefore has state0 = state1 = 1.  All values are reduced
    modulo ``MOD``.

    Args:
        n: Number of vertices; vertices are numbered 0 .. n - 1.
        edges: Iterable of ``(u, v)`` pairs with 0-based endpoints.  The
            recurrence assumes the edges form a tree containing vertex 0.

    Returns:
        state1 of vertex 0, modulo ``MOD``.
    """
    adjacency = [[] for _ in range(n)]
    for u, v in edges:
        adjacency[u].append(v)
        adjacency[v].append(u)

    # Explicit-stack traversal: a recursive DFS needs one Python frame per
    # tree level, which can overflow the interpreter's native stack on
    # path-like trees even after raising the recursion limit.
    parent = [-1] * n
    order = []
    stack = [0]
    while stack:
        v = stack.pop()
        order.append(v)
        for nxt in adjacency[v]:
            if nxt != parent[v]:
                parent[nxt] = v
                stack.append(nxt)

    state0 = [1] * n
    state1 = [1] * n
    # Every vertex appears in `order` after its parent, so walking the list
    # backwards finishes all children before their parent is evaluated.
    for v in reversed(order):
        children = [c for c in adjacency[v] if c != parent[v]]
        sums = [(state0[c] + state1[c]) % MOD for c in children]

        # prefix[i] = product of sums[:i]
        prefix = [1]
        for s in sums:
            prefix.append(prefix[-1] * s % MOD)
        total = prefix[-1]

        # Prefix/suffix products (rather than dividing `total` by one factor)
        # stay correct even if some factor is 0 modulo MOD.
        extra = 0
        suffix = 1  # product of sums[i + 1:]
        for i in range(len(children) - 1, -1, -1):
            extra = (extra + prefix[i] * suffix % MOD * state1[children[i]]) % MOD
            suffix = suffix * sums[i] % MOD

        state0[v] = total
        state1[v] = (total + extra) % MOD

    return state1[0]


def main() -> None:
    """Read the tree from stdin and print the result of ``solve``."""
    readline = sys.stdin.readline
    n = int(readline())
    edges = []
    for _ in range(n - 1):
        a, b = map(int, readline().split())
        edges.append((a - 1, b - 1))  # convert to 0-based vertex ids
    print(solve(n, edges))


if __name__ == "__main__":
    main()