"""Degree-power sum over an n-vertex graph read from stdin, modulo 998244353.

Input format (as read by ``main``): the first token is ``n``, followed by
``n - 1`` pairs of 1-indexed vertex labels ``a b``, one edge per pair.
NOTE(review): the edge count of n - 1 suggests the graph is a tree; the code
itself does not check this.
"""
import sys

p = 998244353


def solve(n, edges, mod=p):
    """Return (sum over v of 2**deg(v), minus 2 if max degree == n - 1) % mod.

    Args:
        n: number of vertices (must be >= 1; ``max`` of an empty degree list
            raises ``ValueError``).
        edges: iterable of ``(a, b)`` pairs of 0-indexed vertex ids.
        mod: modulus for the result (defaults to 998244353).

    Returns:
        The adjusted sum reduced into ``range(mod)``.

    NOTE(review): for n == 1 the single vertex has degree 0 == n - 1, so the
    "minus 2" branch fires and the result is ``mod - 1``. This mirrors the
    original script; confirm it is the intended answer for that case.
    """
    deg = [0] * n
    for a, b in edges:
        deg[a] += 1
        deg[b] += 1
    # Three-argument pow keeps each term small instead of building 2**deg.
    ans = sum(pow(2, d, mod) for d in deg)
    # Some vertex is adjacent to every other vertex (a star, when the input is
    # a tree); the original subtracts 2 exactly once in that case.
    if max(deg) == n - 1:
        ans -= 2
    # Python's % yields a non-negative result even when ans went negative.
    return ans % mod


def main():
    """Read the graph from stdin in one bulk read and print ``solve``'s result."""
    data = sys.stdin.buffer.read().split()
    n = int(data[0])
    # Consume exactly 2 * (n - 1) label tokens, then pair them up.
    labels = iter(map(int, data[1:1 + 2 * (n - 1)]))
    edges = [(a - 1, b - 1) for a, b in zip(labels, labels)]
    print(solve(n, edges))


if __name__ == "__main__":
    main()