from collections import defaultdict import sys sys.setrecursionlimit(10 ** 6) MOD = 998244353 N = int(input()) adj = defaultdict(list) for i in range(N-1): u, v = map(lambda x: int(x)-1, input().split()) adj[u].append(v) adj[v].append(u) def tree_dp(v, par): dp0 = 1 # 頂点 v は孤立点 dp1 = 0 # 頂点 v は端点 dp2 = 0 # 頂点 v は中間点 for to in adj[v]: if to == par: continue r_dp0, r_dp1, r_dp2 = tree_dp(to, v) pp0 = pp1 = pp2 = 0 dp0, pp0 = pp0, dp0 dp1, pp1 = pp1, dp1 dp2, pp2 = pp2, dp2 # v は孤立点のまま : 孤立点 * (端点 or 中間点) dp0 += pp0 * (r_dp1 + r_dp2) # v は端点のまま : 端点 * (端点 or 中間点) dp1 += pp1 * (r_dp1 + r_dp2) # v は中間点のまま : 中間点 * (端点 or 中間点) dp2 += pp2 * (r_dp1 + r_dp2) # v を孤立点から端点へ : 孤立点 * (孤立点 or 端点) dp1 += pp0 * (r_dp0 + r_dp1) # v を端点から中間点へ : 端点 * (孤立点 or 端点) dp2 += pp1 * (r_dp0 + r_dp1) dp0 %= MOD dp1 %= MOD dp2 %= MOD return dp0, dp1, dp2 _, dp1, dp2 = tree_dp(0, -1) ans = (dp1 + dp2) % MOD print(ans)