"""Tree degree-power sum (mod 998244353).

Reads a tree from stdin: the vertex count ``N`` on the first line, then
``N - 1`` edges ``a b`` (1-indexed). Prints the sum over all vertices of
``2 ** deg(v)``, minus 2 when one vertex is adjacent to every other vertex
(a star), reduced modulo 998244353.
"""
import sys

MOD = 998244353


def solve(n, edges, mod=MOD):
    """Return sum(2**deg(v)) over all n vertices, minus 2 for a star, mod ``mod``.

    Args:
        n: Number of vertices (labelled 1..n).
        edges: Iterable of ``(a, b)`` pairs with 1-indexed endpoints.
        mod: Modulus for the result (defaults to 998244353).

    Returns:
        The result as an integer in ``[0, mod)``.
    """
    degree = [0] * n
    for a, b in edges:
        degree[a - 1] += 1
        degree[b - 1] += 1

    total = sum(pow(2, d, mod) for d in degree)

    # Every edge adds 2 to the degree sum, so sum(degree) == 2 * (n - 1).
    # Hence "sum(degree) - max(degree) == n - 1" is exactly
    # "max(degree) == n - 1": some vertex touches all others (a star).
    # NOTE(review): for n == 1 this branch is taken and the answer is
    # mod - 1; confirm that matches the intended problem statement.
    if max(degree) == n - 1:
        total -= 2
    return total % mod


def main():
    """Read the tree from stdin in one pass and print the answer."""
    data = sys.stdin.buffer.read().split()
    n = int(data[0])
    # Consume exactly the 2 * (n - 1) edge endpoints that follow N.
    nums = map(int, data[1:1 + 2 * (n - 1)])
    # zip over the same iterator pairs up consecutive values as (a, b).
    edges = zip(nums, nums)
    print(solve(n, edges))


if __name__ == "__main__":
    main()