import sys sys.setrecursionlimit(10**8) #import pypyjit #pypyjit.set_param('max_unroll_recursion=-1') MOD = 998244353 N = int(input()) jisu = [0] * N G = [list() for _ in range(N)] for i in range(N - 1): u, v = map(int, input().split()) u -= 1 v -= 1 G[u].append(v) G[v].append(u) jisu[u] += 1 jisu[v] += 1 # 木DP DP = [[[0] * (2) for _ in range(3)] for _ in range(N)] def dfs(pos, pre): soseki = [1, 1] memo2 = 1 for nex in G[pos]: if nex == pre: continue dfs(nex, pos) for j in range(2): soseki[j] *= DP[nex][0][0] soseki[j] %= MOD # 茶色 memo2 *= DP[nex][0][0] + DP[nex][1][0] memo2 %= MOD # 茶 DP[pos][2][0] = memo2 # 総積は求められている DP[pos][0][0] = soseki[1] # 赤未 DP[pos][1][0] = soseki[0] # 緑未 memo01 = 0 # 赤既 memo11 = 0 # 緑既 for nex in G[pos]: if nex == pre: continue memo01 += soseki[1] * pow(DP[nex][1][0], -1, MOD) * (DP[nex][2][0] + DP[nex][1][1]) memo01 %= MOD memo11 += soseki[0] * pow(DP[nex][0][0], -1, MOD) * (DP[nex][2][0] + DP[nex][0][1]) memo11 %= MOD # 赤既、緑既 DP[pos][0][1] = memo01 DP[pos][1][1] = memo11 dfs(0, -1) ans = DP[0][0][1] + DP[0][1][1] + DP[0][2][0] ans %= MOD # スターグラフ対応 jisu.sort() for i in range(N - 1): if jisu[i] != 1: break else: ans -= 2 ans %= MOD print(ans)