結果

問題 No.1103 Directed Length Sum
ユーザー NoneNone
提出日時 2021-03-02 00:04:25
言語 PyPy3 (7.3.15)
結果
TLE  
実行時間 -
コード長 10,616 bytes
コンパイル時間 234 ms
コンパイル使用メモリ 82,320 KB
実行使用メモリ 503,848 KB
最終ジャッジ日時 2024-04-14 04:07:15
合計ジャッジ時間 11,365 ms
ジャッジサーバーID (参考情報) judge4 / judge3
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間 実行使用メモリ
testcase_00 AC 41 ms 56,024 KB
testcase_01 AC 40 ms 55,624 KB
testcase_02 AC 1,095 ms 349,324 KB
testcase_03 AC 1,423 ms 503,848 KB
testcase_04 AC 2,250 ms 243,384 KB
testcase_05 TLE -
testcase_06 -- -
testcase_07 -- -
testcase_08 -- -
testcase_09 -- -
testcase_10 -- -
testcase_11 -- -
testcase_12 -- -
testcase_13 -- -
testcase_14 -- -
testcase_15 -- -
testcase_16 -- -
testcase_17 -- -
testcase_18 -- -
testcase_19 -- -
testcase_20 -- -
testcase_21 -- -
testcase_22 -- -
testcase_23 -- -
権限があれば一括ダウンロードができます

ソースコード

diff #

class Tree():
    """Undirected tree with rooting, rerooting DP and heavy-light decomposition.

    Public methods take and return vertex labels shifted by ``decrement``
    (``decrement=1`` means callers use 1-indexed labels); internally every
    vertex is stored 0-indexed.

    Call :meth:`set_root` before the traversal helpers (``diameter``,
    ``rerooting``, ``top_down``, ``bottom_up``, ...). Call
    :meth:`heavy_light_decomposition` as well before ``lca``, ``path``,
    ``distance``, ``find``, ``path_to_list``, ``ver_to_idx``, ``idx_to_ver``,
    ``idx_to_edge`` and ``subtree_query``.
    """

    def __init__(self, n, decrement=1):
        self.n = n
        self.edges = [[] for _ in range(n)]     # 0-indexed adjacency lists
        # edge_ids[u][k] is the label of the edge (u, edges[u][k]).
        self.edge_ids = [[] for _ in range(n)]
        self.root = None
        self.size = [1]*n       # number of nodes in subtree (set by set_root)
        # edge_label[v]: label of the edge (par[v], v); -1 for the root.
        # Filled by set_root.
        self.edge_label = [-1]*n
        self.decrement = decrement

    def add_edge(self, u, v, i):
        """Add the undirected edge u-v (shifted labels) tagged with label ``i``."""
        u, v = u-self.decrement, v-self.decrement
        self.edges[u].append(v)
        self.edges[v].append(u)
        self.edge_ids[u].append(i)
        self.edge_ids[v].append(i)

    def add_edges(self, edges):
        """Add every ``(u, v)`` pair of ``edges``; the k-th pair gets label k."""
        for i, (u, v) in enumerate(edges):
            self.add_edge(u, v, i)

    def set_root(self, root):
        """Root the tree at ``root`` (shifted label).

        Computes ``par``, ``depth``, ``order`` (every parent precedes its
        children; the root's children directly follow the root), ``size``
        (subtree sizes) and ``edge_label``. Calling it again recomputes
        everything from scratch.
        """
        root -= self.decrement
        n = self.n
        self.root = root
        self.depth = depth = [-1]*n
        self.par = par = [-1]*n
        self.edge_label = label = [-1]*n
        depth[root] = 0
        self.order = order = [root]
        stack = [root]
        while stack:
            p = stack.pop()
            for q, e in zip(self.edges[p], self.edge_ids[p]):
                if depth[q] != -1:
                    continue
                par[q] = p
                depth[q] = depth[p]+1
                label[q] = e
                order.append(q)
                stack.append(q)
        # Reset before accumulating so repeated calls do not double-count.
        self.size = size = [1]*n
        # Every descendant of v appears after v in `order`, so walking it
        # backwards finalises size[v] before it is added to its parent.
        for v in reversed(order[1:]):
            size[par[v]] += size[v]

    def diameter(self, path=False):
        """Return the tree diameter (number of edges).

        The search starts from the deepest vertex w.r.t. the current root,
        so :meth:`set_root` must have been called. With ``path=True``, returns
        ``(d, v, u, route)``: ``v`` and ``u`` are the endpoints (shifted
        labels) and ``route`` lists the shifted labels from ``v`` to ``u``.
        """
        u = self.depth.index(max(self.depth))
        dist = [-1]*self.n
        dist[u] = 0
        prev = [-1]*self.n
        stack = [u]
        while stack:
            p = stack.pop()
            for q in self.edges[p]:
                if dist[q] != -1:
                    continue
                dist[q] = dist[p]+1
                prev[q] = p
                stack.append(q)
        d = max(dist)
        if not path:
            return d
        v = w = dist.index(d)
        route = [v+self.decrement]
        while w != u:
            w = prev[w]
            route.append(w+self.decrement)
        return d, v+self.decrement, u+self.decrement, route

    def rerooting(self, op, merge, id):
        """All-directions tree DP; returns one value per (0-indexed) vertex.

        ``op(x, a, b)`` turns the value ``x`` of the component hanging off
        ``b`` into a contribution for ``a`` across edge a-b. ``merge``
        combines contributions, and ``id`` is its identity. Sibling suffixes
        are merged in front of prefixes, so ``merge`` appears to need to be
        commutative as well as associative.
        """
        dp1 = [id] * self.n
        dp2 = [id] * self.n
        for p in self.order[::-1]:
            # Bottom-up pass: dp2[q] gets the merge of q's siblings
            # (prefix, then suffix), and dp1[p] the merge over all children.
            t = id
            for q in self.edges[p]:
                if self.par[p] == q: continue
                dp2[q] = t
                t = merge(t, op(dp1[q], p, q))
            t = id
            for q in self.edges[p][::-1]:
                if self.par[p] == q: continue
                dp2[q] = merge(t, dp2[q])
                t = merge(t, op(dp1[q], p, q))
            dp1[p] = t
        for q in self.order[1:]:
            # Top-down pass: add the contribution coming from the parent side.
            pq = self.par[q]
            dp2[q] = op(merge(dp2[q], dp2[pq]), q, pq)
            dp1[q] = merge(dp1[q], dp2[q])
        return dp1

    def heavy_light_decomposition(self):
        """Compute the HLD and return ``hld``.

        ``hld[i]`` is the shifted label of the vertex at position ``i`` (heavy
        paths are contiguous). Also fills ``vid`` (vertex -> position),
        ``head`` (top of each vertex's heavy path) and ``heavy_node``. Requires
        :meth:`set_root`, and expects all ``n`` vertices to be reachable from
        the root.
        """
        self.vid = [-1]*self.n
        self.hld = [-1]*self.n
        self.head = [-1]*self.n
        self.head[self.root] = self.root
        self.heavy_node = [-1]*self.n
        next_set = [self.root]
        for i in range(self.n):
            # For a tree graph, the DFS ends after exactly n pops.
            p = next_set.pop()
            self.vid[p] = i
            self.hld[i] = p+self.decrement
            maxs = 0
            for q in self.edges[p]:
                # Record the heavy child (largest subtree) in heavy_node.
                if self.par[p] == q: continue
                if maxs < self.size[q]:
                    maxs = self.size[q]
                    self.heavy_node[p] = q
            for q in self.edges[p]:
                # Every light child starts a new heavy path.
                if self.par[p] == q or self.heavy_node[p] == q: continue
                self.head[q] = q
                next_set.append(q)
            if self.heavy_node[p] != -1:
                # Push the heavy child last so it is visited next (contiguous).
                self.head[self.heavy_node[p]] = self.head[p]
                next_set.append(self.heavy_node[p])
        return self.hld

    def lca(self, u, v):
        """Return the lowest common ancestor of ``u`` and ``v`` (shifted labels)."""
        u, v = u-self.decrement, v-self.decrement
        while True:
            if self.vid[u] > self.vid[v]: u, v = v, u
            if self.head[u] != self.head[v]:
                v = self.par[self.head[v]]
            else:
                return u + self.decrement

    def path(self, u, v):
        """Yield the shifted labels on the u-v path, in order from u to v."""
        p = self.lca(u, v)
        u, v, p = u-self.decrement, v-self.decrement, p-self.decrement
        R = []
        while u != p:
            yield u+self.decrement
            u = self.par[u]
        yield p+self.decrement
        while v != p:
            R.append(v)
            v = self.par[v]
        for v in reversed(R):
            yield v+self.decrement

    def distance(self, u, v):
        """Return the number of edges between ``u`` and ``v``."""
        p = self.lca(u, v)
        u, v, p = u-self.decrement, v-self.decrement, p-self.decrement
        return self.depth[u] + self.depth[v] - 2*self.depth[p]

    def find(self, u, v, x):
        """Return True if ``x`` lies on the u-v path."""
        return self.distance(u,x)+self.distance(x,v)==self.distance(u,v)

    def path_to_list(self, u, v, edge_query=False):
        """Yield half-open ranges ``(l, r)`` on ``hld`` covering the u-v path.

        edge_query: map each edge (par, chi) to the position of chi, so the
                LCA's own position is excluded (the root is never updated).
        """
        u, v = u-self.decrement, v-self.decrement
        while True:
            if self.vid[u] > self.vid[v]: u, v = v, u
            if self.head[u] != self.head[v]:
                yield self.vid[self.head[v]], self.vid[v] + 1
                v = self.par[self.head[v]]
            else:
                yield self.vid[u] + edge_query, self.vid[v] + 1
                return

    def ver_to_idx(self, u):
        """Return the position on ``hld`` of vertex ``u`` (shifted label)."""
        return self.vid[u-self.decrement]

    def idx_to_ver(self, i):
        """Return the shifted vertex label at position ``i`` on ``hld``."""
        return self.hld[i]

    def idx_to_edge(self, i):
        """Return the label of the edge entering the vertex at position ``i``."""
        return self.edge_label[self.hld[i]-self.decrement]

    def subtree_query(self, u):
        """Return the half-open ``hld`` range covering the subtree of ``u``."""
        u -= self.decrement
        return self.vid[u], self.vid[u] + self.size[u]

    def top_down(self,dp):
        """Fold each parent's value into its child (XOR), root first; mutates dp."""
        def merge(dp_chi,dp_par):
            return dp_chi^dp_par
        for chi in self.order[1:]:
            par=self.par[chi]
            dp[chi]=merge(dp[chi], dp[par])
        return dp

    def top_down_edge_query(self,dp):
        """Like top_down, but ``dp`` is indexed by edge label.

        The root's children are skipped because their parent edge has no
        parent edge; they occupy order[1:1+deg(root)].
        """
        def merge(dp_chi,dp_par):
            return dp_chi^dp_par
        for p in self.order[1+len(self.edges[self.root]):]:
            chi,par=self.edge_label[p],self.edge_label[self.par[p]]
            dp[chi]=merge(dp[chi], dp[par])
        return dp

    def bottom_up(self,dp):
        """Add each child's value into its parent, leaves first; mutates dp."""
        def merge(dp_chi,dp_par):
            return dp_chi+dp_par
        for par in self.order[::-1]:
            for chi in self.edges[par]:
                if self.par[par] == chi: continue
                dp[par]=merge(dp[chi],dp[par])
        return dp

    def draw(self):
        """Plot the tree with networkx/matplotlib (shifted labels)."""
        import matplotlib.pyplot as plt
        import networkx as nx

        G = nx.Graph()
        for x in range(self.n):
            for y in self.edges[x]:
                G.add_edge(x + self.decrement, y + self.decrement)
        pos = nx.spring_layout(G)
        nx.draw_networkx(G, pos)
        plt.axis("off")
        plt.show()


#########################################################################################################

def make_tree(N, show=True, decrement=1):
    """Return the edge list of a random tree on ``N`` vertices.

    For each vertex ``i`` in ``0 .. N-2``, draw ``j`` uniformly from
    ``0 .. N-1`` until it lies in a different component, then join them. A
    union-find (path compression + union by size) tracks the components.
    Edges are emitted as ``(i + decrement, j + decrement)``, so
    ``decrement=1`` yields 1-indexed labels. When ``show`` is true, the
    tree is also printed in "N, then one edge per line" format between
    separator lines.
    """
    import random

    # uf[x] < 0: x is a root and -uf[x] is its component size; else parent.
    uf = [-1] * N

    def root_of(x):
        trail = []
        while uf[x] >= 0:
            trail.append(x)
            x = uf[x]
        for y in trail:
            uf[y] = x
        return x

    def join(a, b):
        ra, rb = root_of(a), root_of(b)
        if ra == rb:
            return
        if uf[ra] > uf[rb]:
            ra, rb = rb, ra
        uf[ra] += uf[rb]
        uf[rb] = ra

    result = []
    for i in range(N - 1):
        j = random.randint(0, N - 1)
        while root_of(i) == root_of(j):
            j = random.randint(0, N - 1)
        join(i, j)
        result.append((i + decrement, j + decrement))

    if show:
        bar = "#######################"
        print(bar, N, *(f"{a} {b}" for a, b in result), bar, sep="\n")
    return result

def make_special_tree(N, type, show=True, decrement=1):
    """Return the edge list of a structured tree on ``N`` vertices.

    type:
        "path"   -- 0-1-2-...-(N-1)
        "star"   -- vertex 0 joined to every other vertex
        "binary" -- heap layout: vertex k has children 2k+1 and 2k+2
    Any other value yields an empty list. Labels are shifted by
    ``decrement``, so ``decrement=1`` yields 1-indexed labels. When ``show``
    is true, the tree is also printed in "N, then one edge per line" format.
    """
    edges = []
    if type == "path":
        edges = [(i + decrement, i + 1 + decrement) for i in range(N - 1)]
    elif type == "star":
        # The centre is vertex 0, shifted like every other label.
        edges = [(decrement, i + 1 + decrement) for i in range(N - 1)]
    elif type == "binary":
        # In the 0-indexed heap layout, the parent of vertex c (c >= 1) is
        # (c - 1) // 2. Every vertex except the root gets exactly one edge.
        edges = [((c - 1) // 2 + decrement, c + decrement) for c in range(1, N)]
    if show:
        print("#######################")
        print(N)
        for p in edges:
            print(*p)
        print("#######################")
    return edges

def draw(edges, decrement=1):
    """Plot a tree given as a list of ``(u, v)`` edge pairs.

    ``edges`` has the shape returned by ``make_tree`` / ``make_special_tree``:
    pairs whose labels are already shifted by ``decrement``. The pairs are
    drawn as-is. ``decrement`` is the label of the first vertex and is used
    to add all ``len(edges) + 1`` vertices, so a one-vertex tree still shows
    up. Requires matplotlib and networkx.
    """
    import matplotlib.pyplot as plt
    import networkx as nx

    n = len(edges) + 1
    G = nx.Graph()
    G.add_nodes_from(range(decrement, n + decrement))
    # The old loop indexed the pair list as an adjacency list and always
    # ran past its end; add the pairs directly instead.
    G.add_edges_from(edges)
    pos = nx.spring_layout(G)
    nx.draw_networkx(G, pos)
    plt.axis("off")
    plt.show()



#########################################################################################################
import sys
# Rebind input() to sys.stdin.readline for the rest of this module. Unlike
# the builtin, the returned line keeps its trailing newline.
input = sys.stdin.readline

# example()


def dfs(start=0, goal=None, graph=None):
    """Return the first vertex reached from ``start`` that has no neighbours.

    Iterative depth-first search over ``graph``, an adjacency list indexed
    by vertex. When ``graph`` is omitted, the module-level ``edges`` list is
    used, which the original main script fills with each vertex's parent, so
    the search climbs to the root. Returns -1 if every reachable vertex has
    at least one neighbour.

    ``goal`` is accepted for backward compatibility but is not used.
    """
    if graph is None:
        graph = edges
    visited = {start}
    stack = [start]
    while stack:
        p = stack.pop()
        if not graph[p]:
            return p
        for q in graph[p]:
            if q not in visited:
                visited.add(q)
                stack.append(q)
    return -1

MOD = 10**9 + 7


def solve(n, pairs, mod=MOD):
    """Return the sum of dist(u, v) over all ancestor/descendant pairs, mod ``mod``.

    ``pairs`` holds the n-1 directed edges ``(x, y)`` (1-indexed), meaning x
    is the parent of y. The root is the single vertex that is never a child;
    ValueError is raised if there is none.

    The edge parent(v) -> v lies on the path from each of the depth(v)
    ancestors of v to each of the size(v) vertices in v's subtree, so it
    contributes size(v) * depth(v).
    """
    parent = [-1] * n
    children = [[] for _ in range(n)]
    for x, y in pairs:
        parent[y - 1] = x - 1
        children[x - 1].append(y - 1)
    root = parent.index(-1)

    # Breadth-first order: every vertex appears after its parent.
    depth = [0] * n
    order = [root]
    for v in order:  # `order` grows during iteration and visits every vertex
        for c in children[v]:
            depth[c] = depth[v] + 1
            order.append(c)

    # Reverse BFS order finishes each subtree before its root.
    size = [1] * n
    total = 0
    for v in reversed(order):
        p = parent[v]
        if p != -1:
            size[p] += size[v]
            total += size[v] * depth[v]
    return total % mod


def main():
    """Read N and N-1 lines "x y" from stdin and print the answer."""
    # One bulk read is much cheaper than N readline() calls for large N.
    data = sys.stdin.buffer.read().split()
    n = int(data[0])
    nums = map(int, data[1:2 * n - 1])
    print(solve(n, zip(nums, nums)))


if __name__ == "__main__":
    main()