# Prime modulus shared by every ModInt operation.
MOD = 998244353


def _raw_value(x):
    """Return the plain integer behind *x* (a ModInt or an int)."""
    return x.v if isinstance(x, ModInt) else x


class ModInt:
    """Integer reduced modulo ``MOD`` (998244353).

    Binary operators accept either another ``ModInt`` or a plain ``int`` on
    either side and always return a new ``ModInt``.  The reduced value is
    stored in the attribute ``v`` (always in ``range(MOD)`` for int input).
    """

    def __init__(self, value):
        self.v = value % MOD

    def __add__(self, other):
        return ModInt(self.v + _raw_value(other))

    # Addition commutes, so ``int + ModInt`` can reuse ``__add__``.
    __radd__ = __add__

    def __sub__(self, other):
        return ModInt(self.v - _raw_value(other))

    def __rsub__(self, other):
        return ModInt(_raw_value(other) - self.v)

    def __mul__(self, other):
        return ModInt(self.v * _raw_value(other))

    # Multiplication commutes, so ``int * ModInt`` can reuse ``__mul__``.
    __rmul__ = __mul__

    def __pow__(self, exp):
        return ModInt(pow(self.v, exp, MOD))

    def __truediv__(self, other):
        """Multiply by the modular inverse of *other*.

        Raises:
            ZeroDivisionError: if *other* is congruent to 0 modulo ``MOD``.
                (Without this check the Fermat inverse ``0 ** (MOD - 2)``
                is 0 and the division would silently yield 0.)
        """
        divisor = _raw_value(other) % MOD
        if divisor == 0:
            raise ZeroDivisionError("division by zero modulo 998244353")
        # MOD is prime, so divisor ** (MOD - 2) is the inverse (Fermat).
        return ModInt(self.v * pow(divisor, MOD - 2, MOD))

    def __eq__(self, other):
        if isinstance(other, ModInt):
            return self.v == other.v
        if isinstance(other, int):
            # Compared without reduction so that equality stays consistent
            # with ``__hash__`` below.
            return self.v == other
        return NotImplemented

    def __hash__(self):
        return hash(self.v)

    def __int__(self):
        return self.v

    def __repr__(self):
        return str(self.v)


def _bitsort(bit):
    """Normalise, in place, a list of exponents standing for a sum of 2**e.

    ``-1`` marks an empty slot.  Each sweep merges an equal adjacent pair
    (``2**e + 2**e == 2**(e + 1)``) into the left slot and otherwise bubbles
    the larger exponent to the left.  The number of sweeps equals the list
    length.

    NOTE(review): the fixed sweep count is taken over from the original code;
    it was sufficient on every input traced by hand, but no proof is given
    here that it always fully normalises the list.
    """
    size = len(bit)
    for _ in range(size):
        for j in range(1, size):
            if bit[j] != -1 and bit[j] == bit[j - 1]:
                bit[j - 1] = bit[j] + 1
                bit[j] = -1
            elif bit[j] > bit[j - 1]:
                bit[j], bit[j - 1] = bit[j - 1], bit[j]


def greater_check(a, b):
    """Compare two triples without building huge integers.

    With ``A = sum(a)`` and ``B = sum(b)`` this decides whether

        sum(2**(a[i] + 1 + B)) + 2**(A + 1) + 2**A

    is strictly greater than the same expression with the roles of ``a`` and
    ``b`` swapped.  Only the exponents are stored; they are normalised by
    ``_bitsort`` and then compared lexicographically.

    Args:
        a: list of 3 ints.
        b: list of 3 ints.

    Returns:
        True if the value built from ``a`` is strictly larger, else False
        (including when both are equal).
    """
    ak = sum(a)
    bk = sum(b)
    abit = [a[i] + 1 + bk for i in range(3)] + [ak + 1, ak]
    bbit = [b[i] + 1 + ak for i in range(3)] + [bk + 1, bk]
    _bitsort(abit)
    _bitsort(bbit)
    # Both lists have length 5, so lexicographic list comparison matches an
    # element-by-element scan that stops at the first difference.
    return abit > bbit


def bfs(G, N, start):
    """Multi-source breadth-first search.

    Args:
        G: adjacency lists; ``G[v]`` lists the neighbours of vertex ``v``.
        N: number of vertices.
        start: iterable of source vertices.

    Returns:
        ``(order, par)`` where ``order`` lists the reached vertices in visit
        order (sources first) and ``par[v]`` is the vertex from which ``v``
        was discovered; ``par[s] == s`` for every source and ``par[v] == -1``
        for vertices that were never reached.
    """
    par = [-1] * N
    order = []
    for s in start:
        par[s] = s
        order.append(s)
    head = 0
    # ``order`` doubles as the queue; ``head`` is the index of its front.
    while head < len(order):
        v = order[head]
        head += 1
        for u in G[v]:
            if par[u] == -1:
                par[u] = v
                order.append(u)
    return order, par


def solve(N, G):
    """Compute the answer for a tree given as 0-based adjacency lists.

    Args:
        N: number of vertices (at least 1).
        G: adjacency lists of the tree.

    Returns:
        ``2**(N + 2) - 2**(N - K - 1) * (2**(x + 2) + 2**(y + 2) + 2**(z + 2) - 6)``
        modulo 998244353, where ``[x, y, z]`` is the triple selected below
        and ``K = x + y + z``.
    """
    # 1. The last vertex in BFS order is a farthest vertex from vertex 0.
    order, par = bfs(G, N, [0])
    far_end = order[-1]

    # 2. A second BFS from that vertex ends at the other end of a diameter.
    order, par = bfs(G, N, [far_end])
    other_end = order[-1]

    # 3. Recover the diameter path by walking parent links back to the root.
    diameter = []
    t = other_end
    while True:
        diameter.append(t)
        if par[t] == t:
            break
        t = par[t]

    # 4. BFS from all diameter vertices at once: ``par`` now points towards
    #    the diameter.
    order, par = bfs(G, N, diameter)

    # 5. depth[v] = height of the part hanging below v.  All sources sit at
    #    the front of ``order``, so a reverse scan can stop at the first one.
    depth = [0] * N
    for v in reversed(order):
        if par[v] == v:
            break
        depth[par[v]] = max(depth[par[v]], depth[v] + 1)

    # 6. For every diameter vertex build the sorted triple
    #    (distance to one end, distance to the other end, hanging depth) and
    #    replace the current best whenever greater_check(best, candidate)
    #    holds.
    best = [0, 0, 0]
    last = len(diameter) - 1
    for t, v in enumerate(diameter):
        candidate = sorted((t, last - t, depth[v]))
        if greater_check(best, candidate):
            best = candidate

    K = sum(best)
    bracket = sum((ModInt(2) ** (d + 2) for d in best), ModInt(0)) - 6
    ans = ModInt(2) ** (N + 2) - (ModInt(2) ** (N - K - 1)) * bracket
    return int(ans)


def main():
    """Read a tree from stdin (N, then N - 1 one-based edges) and print the answer."""
    import sys

    # No recursion is used anywhere, so the recursion limit is left alone.
    readline = sys.stdin.readline
    N = int(readline())
    G = [[] for _ in range(N)]
    for _ in range(N - 1):
        U, V = map(int, readline().split())
        U -= 1
        V -= 1
        G[U].append(V)
        G[V].append(U)
    print(solve(N, G))


if __name__ == "__main__":
    main()