結果

問題 No.1507 Road Blocked
ユーザー NoneNone
提出日時 2021-10-01 13:29:49
言語 PyPy3
(7.3.15)
結果
AC  
実行時間 538 ms / 2,000 ms
コード長 11,238 bytes
コンパイル時間 851 ms
コンパイル使用メモリ 86,820 KB
実行使用メモリ 121,888 KB
最終ジャッジ日時 2023-09-26 05:32:57
合計ジャッジ時間 17,751 ms
ジャッジサーバーID
(参考情報)
judge12 / judge11
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 76 ms
71,332 KB
testcase_01 AC 70 ms
71,164 KB
testcase_02 AC 71 ms
71,220 KB
testcase_03 AC 237 ms
118,180 KB
testcase_04 AC 407 ms
121,700 KB
testcase_05 AC 401 ms
121,732 KB
testcase_06 AC 399 ms
121,888 KB
testcase_07 AC 398 ms
121,652 KB
testcase_08 AC 391 ms
121,500 KB
testcase_09 AC 389 ms
121,776 KB
testcase_10 AC 397 ms
121,388 KB
testcase_11 AC 538 ms
121,692 KB
testcase_12 AC 521 ms
121,624 KB
testcase_13 AC 527 ms
121,460 KB
testcase_14 AC 513 ms
121,504 KB
testcase_15 AC 526 ms
121,488 KB
testcase_16 AC 412 ms
121,620 KB
testcase_17 AC 412 ms
121,628 KB
testcase_18 AC 412 ms
121,624 KB
testcase_19 AC 420 ms
121,716 KB
testcase_20 AC 411 ms
121,344 KB
testcase_21 AC 399 ms
121,424 KB
testcase_22 AC 434 ms
121,592 KB
testcase_23 AC 409 ms
121,692 KB
testcase_24 AC 418 ms
121,604 KB
testcase_25 AC 420 ms
121,776 KB
testcase_26 AC 407 ms
121,604 KB
testcase_27 AC 393 ms
118,212 KB
testcase_28 AC 413 ms
121,720 KB
testcase_29 AC 420 ms
121,596 KB
testcase_30 AC 409 ms
121,628 KB
testcase_31 AC 418 ms
121,540 KB
testcase_32 AC 410 ms
121,416 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

class Tree():
    """Tree utilities: rooting, subtree sizes, diameter, rerooting DP,
    heavy-light decomposition (HLD), LCA and path/subtree index queries.

    Vertex labels passed to and returned from the public methods are offset by
    ``decrement`` (use ``decrement=1`` for 1-indexed input); internally the
    vertices are 0..n-1. Call ``set_root`` before any rooted query, and
    ``heavy_light_decomposition`` before lca/path/distance/path_to_list and the
    index-conversion helpers.
    """

    def __init__(self, n, decrement=1):
        self.n = n
        self.edges = [[] for _ in range(n)]        # adjacency lists (0-indexed)
        self._edge_label = [[] for _ in range(n)]  # adjacency as (neighbour, edge id)
        self.root = None
        self.size = [1]*n       # number of nodes in subtree (filled by set_root)
        self.decrement = decrement

    def add_edge(self, u, v, i):
        """Add the undirected edge u-v (labels offset by decrement) with edge id i."""
        u, v = u-self.decrement, v-self.decrement
        self.edges[u].append(v)
        self.edges[v].append(u)
        self._edge_label[u].append((v, i))
        self._edge_label[v].append((u, i))

    def add_edges(self, edges):
        """Add every (u, v) pair of ``edges``; the k-th pair receives edge id k."""
        for i, (u, v) in enumerate(edges):
            self.add_edge(u, v, i)

    def set_root(self, root):
        """Root the tree at ``root`` and compute par, chi, depth, edge_label,
        order (DFS preorder, every parent before its children) and size.

        Safe to call repeatedly (e.g. to re-root): all derived arrays,
        including ``size``, are rebuilt from scratch.
        """
        root -= self.decrement
        self.depth = [-1]*self.n
        self.root = root
        self.par = [-1]*self.n
        self.depth[root] = 0
        self.edge_label = [-1]*self.n   # id of edge (par[v], v); -1 for the root
        self.order = []
        self.chi = [[] for _ in range(self.n)]
        # Reset sizes; otherwise a second call would add onto the old totals.
        self.size = [1]*self.n
        stack = [root]
        while stack:
            p = stack.pop()
            self.order.append(p)
            for q, i in self._edge_label[p]:
                if self.depth[q] != -1: continue   # already visited (the parent)
                self.par[q] = p
                self.chi[p].append(q)
                self.depth[q] = self.depth[p]+1
                self.edge_label[q] = i
                stack.append(q)
        # Reverse preorder visits children before parents.
        for p in reversed(self.order):
            for q in self.edges[p]:
                if self.par[p] == q: continue
                self.size[p] += self.size[q]

    def diameter(self, path=False):
        """Return the diameter length (in edges); requires set_root.

        With ``path=True`` return ``(d, v, u, route)`` instead, where u and v
        are the endpoints and route lists the vertices from v to u. All labels
        are offset by decrement.
        """
        # A deepest vertex from any root is an endpoint of some diameter.
        u = self.depth.index(max(self.depth))
        dist = [-1]*self.n
        dist[u] = 0
        prev = [-1]*self.n
        stack = [u]
        while stack:
            p = stack.pop()
            for q in self.edges[p]:
                if dist[q] != -1: continue
                dist[q] = dist[p]+1
                prev[q] = p
                stack.append(q)
        d = max(dist)
        if not path:
            return d
        v = w = dist.index(d)
        route = [v+self.decrement]
        while w != u:
            w = prev[w]
            route.append(w+self.decrement)
        return d, v+self.decrement, u+self.decrement, route

    def rerooting(self, op, merge, id):
        """Rerooting DP; requires set_root.

        ``op(x, a, b)`` lifts the aggregate ``x`` of the component on b's side
        of edge a-b over to vertex a. ``merge`` combines aggregates, and ``id``
        is its identity. Returns a list whose entry v merges
        ``op(..., v, w)`` over every neighbour w of v.
        """
        dp1 = [id] * self.n
        dp2 = [id] * self.n
        for p in self.order[::-1]:
            # prefix pass: dp2[q] = merge of earlier siblings' contributions
            t = id
            for q in self.edges[p]:
                if self.par[p] == q: continue
                dp2[q] = t
                t = merge(t, op(dp1[q], p, q))
            # suffix pass: dp2[q] = merge(later siblings, earlier siblings)
            t = id
            for q in self.edges[p][::-1]:
                if self.par[p] == q: continue
                dp2[q] = merge(t, dp2[q])
                t = merge(t, op(dp1[q], p, q))
            dp1[p] = t
        # top-down: add the contribution coming from each vertex's parent side
        for q in self.order[1:]:
            pq = self.par[q]
            dp2[q] = op(merge(dp2[q], dp2[pq]), q, pq)
            dp1[q] = merge(dp1[q], dp2[q])
        return dp1

    def heavy_light_decomposition(self):
        """Compute the HLD and return ``self.hld``.

        The returned list holds the vertices (offset by decrement) in HLD
        order, so every heavy path occupies a contiguous index range. Also
        fills vid (vertex -> index), head (top vertex of each heavy path) and
        heavy_node (heaviest child, -1 for leaves). Requires set_root.
        """
        self.vid = [-1]*self.n
        self.hld = [-1]*self.n
        self.head = [-1]*self.n
        self.head[self.root] = self.root
        self.heavy_node = [-1]*self.n
        stack = [self.root]
        for i in range(self.n):
            # on a tree the DFS pops exactly n vertices
            p = stack.pop()
            self.vid[p] = i
            self.hld[i] = p+self.decrement
            # pick the child with the largest subtree as the heavy child
            maxs = 0
            for q in self.edges[p]:
                if self.par[p] == q: continue
                if maxs < self.size[q]:
                    maxs = self.size[q]
                    self.heavy_node[p] = q
            # light children start new heavy paths
            for q in self.edges[p]:
                if self.par[p] == q or self.heavy_node[p] == q: continue
                self.head[q] = q
                stack.append(q)
            # heavy child is pushed last so it is popped next (contiguous ids)
            if self.heavy_node[p] != -1:
                self.head[self.heavy_node[p]] = self.head[p]
                stack.append(self.heavy_node[p])
        return self.hld

    def lca(self, u, v):
        """Lowest common ancestor of u and v; requires heavy_light_decomposition."""
        u, v = u-self.decrement, v-self.decrement
        while True:
            if self.vid[u] > self.vid[v]: u, v = v, u
            if self.head[u] != self.head[v]:
                v = self.par[self.head[v]]
            else:
                return u + self.decrement

    def path(self, u, v):
        """Yield the vertices of the u-v path in order, from u to v inclusive."""
        p = self.lca(u, v)
        u, v, p = u-self.decrement, v-self.decrement, p-self.decrement
        R = []
        while u != p:
            yield u+self.decrement
            u = self.par[u]
        yield p+self.decrement
        while v != p:
            R.append(v)
            v = self.par[v]
        for v in reversed(R):
            yield v+self.decrement

    def distance(self, u, v):
        """Number of edges between u and v; requires heavy_light_decomposition."""
        p = self.lca(u, v)
        u, v, p = u-self.decrement, v-self.decrement, p-self.decrement
        return self.depth[u] + self.depth[v] - 2*self.depth[p]

    def find(self, u, v, x):
        """Return True if vertex x lies on the u-v path."""
        return self.distance(u, x)+self.distance(x, v) == self.distance(u, v)

    def path_to_list(self, u, v, edge_query=False):
        """Yield half-open ranges [l, r) on self.hld that cover the u-v path.

        edge_query: if true, the LCA's own index is excluded, so each edge
                (par, chi) is represented by the index of chi
                (note: the root is never updated)
        """
        u, v = u-self.decrement, v-self.decrement
        while True:
            if self.vid[u] > self.vid[v]: u, v = v, u
            if self.head[u] != self.head[v]:
                yield self.vid[self.head[v]], self.vid[v] + 1
                v = self.par[self.head[v]]
            else:
                yield self.vid[u] + edge_query, self.vid[v] + 1
                return

    def ver_to_idx(self, u):
        """ return index on self.hld corresponding to vertex u """
        return self.vid[u-self.decrement]

    def idx_to_ver(self, i):
        """ from index i on self.hld to vertex-index """
        return self.hld[i]

    def idx_to_edge(self, i):
        """ from index i on self.hld to edge-index """
        return self.edge_label[self.hld[i]-self.decrement]

    def subtree_query(self, u):
        """Half-open index range [l, r) on self.hld covering u's subtree."""
        u -= self.decrement
        return self.vid[u], self.vid[u] + self.size[u]

    def top_down(self, dp):
        """XOR each vertex's value with its (already updated) parent's value,
        root first. dp is indexed by vertex and modified in place."""
        def merge(dp_chi, dp_par):
            return dp_chi ^ dp_par
        for chi in self.order[1:]:
            par = self.par[chi]
            dp[chi] = merge(dp[chi], dp[par])
        return dp

    def top_down_edge_query(self, dp):
        """Edge version of top_down: dp is indexed by edge id.

        For every edge (par[p], p) whose upper endpoint is not the root, XOR
        in the (already updated) value of the edge directly above it.
        """
        def merge(dp_chi, dp_par):
            return dp_chi ^ dp_par
        # self.order is DFS preorder (not BFS), so the root's children are not
        # necessarily order[1:1+deg(root)]. Skip them explicitly instead.
        for p in self.order[1:]:
            if self.par[p] == self.root:
                continue   # edges at the root have no edge above them
            chi, par = self.edge_label[p], self.edge_label[self.par[p]]
            dp[chi] = merge(dp[chi], dp[par])
        return dp

    def bottom_up(self, dp):
        """Add every child's value into its parent, leaves first (subtree sums)."""
        def merge(dp_chi, dp_par):
            return dp_chi+dp_par
        for par in self.order[::-1]:
            for chi in self.edges[par]:
                if self.par[par] == chi: continue
                dp[par] = merge(dp[chi], dp[par])
        return dp

    def draw(self):
        """Plot the tree with networkx/matplotlib (labels offset by decrement)."""
        import matplotlib.pyplot as plt
        import networkx as nx

        G = nx.Graph()
        for x in range(self.n):
            for y in self.edges[x]:
                G.add_edge(x + self.decrement, y + self.decrement)
        pos = nx.spring_layout(G)
        nx.draw_networkx(G, pos)
        plt.axis("off")
        plt.show()

#########################################################################################################

def example_tree(N=10,show=True,decrement=1):
    """Generate a random tree on N vertices and make ``input`` replay it.

    For each i in 0..N-2, partner vertices j are drawn with
    random.randint(0, N-1) until i and j lie in different components. The edge
    (i, j) is then added, so the N-1 edges always form a spanning tree. Labels
    are shifted by ``decrement`` (1 gives a 1-indexed tree).

    Side effect: rebinds the module-level ``input`` so that successive calls
    return "N" and then one "u v" line per edge. If ``show`` is true the same
    text is also printed between separator lines.

    Returns the list of (u, v) edges.
    """
    global input
    import random

    # Union-find: uf[v] < 0 marks a root whose component has -uf[v] vertices.
    uf = [-1] * N

    def leader(v):
        trail = []
        while uf[v] >= 0:
            trail.append(v)
            v = uf[v]
        for w in trail:  # path compression
            uf[w] = v
        return v

    edges = []
    for i in range(N - 1):
        j = random.randint(0, N - 1)
        while leader(i) == leader(j):
            j = random.randint(0, N - 1)
        a, b = leader(i), leader(j)
        if uf[a] > uf[b]:  # attach the smaller component under the larger
            a, b = b, a
        uf[a] += uf[b]
        uf[b] = a
        edges.append((i + decrement, j + decrement))

    lines = [str(N)] + [" ".join(map(str, e)) for e in edges]
    input = iter(lines).__next__

    if show:
        banner = "#######################"
        print(banner)
        print(N)
        for e in edges:
            print(*e)
        print(banner)
    return edges

def example_special_tree(N, type, show=True, decrement=1):
    """Build a tree of a fixed shape on N vertices and make ``input`` replay it.

    type:
        "path"   -> 0-1-2-...-(N-1)
        "star"   -> vertex 0 joined to every other vertex
        "binary" -> heap shape: vertex c (c >= 1) hangs under (c-1)//2
        anything else -> no edges
    Vertex labels are shifted by ``decrement``.

    Side effect: rebinds the module-level ``input`` so that successive calls
    return "N" and then one "u v" line per edge. If ``show`` is true the same
    text is also printed between separator lines.

    Returns the list of (u, v) edges.
    """
    global input

    if type == "path":
        edges = [(i+decrement, i+1+decrement) for i in range(N-1)]
    elif type == "star":
        # every edge touches the centre vertex 0 (shifted by decrement)
        edges = [(decrement, i+1+decrement) for i in range(N-1)]
    elif type == "binary":
        # all N-1 parent/child pairs, in order (0,1),(0,2),(1,3),(1,4),...
        edges = [((c-1)//2 + decrement, c + decrement) for c in range(1, N)]
    else:
        edges = []

    input_data = [str(N)]
    for i, j in edges:
        input_data.append(" ".join(map(str, (i, j))))
    input_data = iter(input_data)
    input = lambda: next(input_data)

    if show:
        print("#######################")
        print(N)
        for p in edges:
            print(*p)
        print("#######################")
    return edges

def _draw_nodes_and_edges(edges, decrement=1):
    """Return (vertex labels, edge pairs) for a tree given as (u, v) pairs.

    A tree with len(edges) edges has len(edges)+1 vertices, labelled
    decrement .. len(edges)+decrement.
    """
    n = len(edges) + 1
    nodes = list(range(decrement, n + decrement))
    pairs = [(u, v) for u, v in edges]
    return nodes, pairs


def draw(edges, decrement=1):
    """Plot a tree given as a list of (u, v) edge pairs, then call plt.show().

    ``edges`` is expected in the form returned by example_tree /
    example_special_tree, i.e. with vertex labels already offset by
    ``decrement``. ``decrement`` is only used to label the N = len(edges)+1
    vertices, so a single-vertex tree (no edges) is still drawn.

    The previous loop indexed ``edges[x]`` for x in range(len(edges)+1) and
    therefore always raised IndexError.
    """
    import matplotlib.pyplot as plt
    import networkx as nx

    nodes, pairs = _draw_nodes_and_edges(edges, decrement)
    G = nx.Graph()
    G.add_nodes_from(nodes)
    G.add_edges_from(pairs)
    pos = nx.spring_layout(G)
    nx.draw_networkx(G, pos)
    plt.axis("off")
    plt.show()

def example():
    """Point the module-level ``input`` at a built-in sample test case.

    Later input() calls return the sample's lines one at a time, without
    trailing newlines. Once the lines are exhausted they raise StopIteration.
    """
    global input
    sample = """
5 2
1 5
2 3
3 5
4 3
1 2 3
1 4 2
        """
    input = iter(sample.strip().splitlines()).__next__

#########################################################################################################
import sys
input = sys.stdin.readline

# Solution script: read a tree on N vertices (1-indexed edges from stdin) and
# print (1 - 2*S / (N*(N-1)*(N-1))) mod MOD, where S sums, over every edge,
# the product of the vertex counts on its two sides.
MOD = 998244353
N = int(input())
T = Tree(N, decrement=0)
for edge_id in range(N - 1):
    a, b = (int(tok) - 1 for tok in input().split())  # convert to 0-indexed
    T.add_edge(a, b, edge_id)
T.set_root(0)

# Each non-root vertex v stands for the edge (par[v], v); removing it splits
# the tree into size[v] and N - size[v] vertices.
S = sum(T.size[v] * (N - T.size[v]) for v in range(1, N)) % MOD
# MOD is prime, so a single Fermat inverse covers the whole denominator.
inv_den = pow(N * (N - 1) * (N - 1), MOD - 2, MOD)
print((1 - 2 * S * inv_den) % MOD)