結果

問題 No.1103 Directed Length Sum
ユーザー terasaterasa
提出日時 2022-11-20 12:55:19
言語 PyPy3
(7.3.15)
結果
TLE  
実行時間 -
コード長 3,620 bytes
コンパイル時間 162 ms
コンパイル使用メモリ 81,768 KB
実行使用メモリ 409,260 KB
最終ジャッジ日時 2023-10-21 12:10:13
合計ジャッジ時間 15,338 ms
ジャッジサーバーID
(参考情報)
judge13 / judge11
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 59 ms
67,628 KB
testcase_01 AC 60 ms
67,904 KB
testcase_02 AC 1,589 ms
393,108 KB
testcase_03 AC 1,273 ms
409,260 KB
testcase_04 AC 2,109 ms
194,720 KB
testcase_05 TLE -
testcase_06 AC 1,368 ms
154,840 KB
testcase_07 AC 374 ms
97,412 KB
testcase_08 AC 533 ms
107,928 KB
testcase_09 AC 262 ms
91,196 KB
testcase_10 AC 701 ms
118,020 KB
testcase_11 AC 2,277 ms
205,592 KB
testcase_12 AC 1,382 ms
155,792 KB
testcase_13 AC 714 ms
118,908 KB
testcase_14 AC 253 ms
88,120 KB
testcase_15 AC 1,063 ms
138,816 KB
testcase_16 AC 2,604 ms
221,344 KB
testcase_17 AC 2,654 ms
227,096 KB
testcase_18 AC 704 ms
117,276 KB
testcase_19 AC 2,329 ms
208,304 KB
testcase_20 AC 302 ms
93,652 KB
testcase_21 AC 473 ms
104,844 KB
testcase_22 AC 1,947 ms
184,616 KB
testcase_23 AC 1,164 ms
142,952 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

from typing import List, Tuple, Callable, TypeVar
from typing import List, Tuple, Optional
import sys
import itertools
import heapq
import bisect
import math
from collections import deque, defaultdict
from functools import lru_cache, cmp_to_key

input = sys.stdin.readline

# Raise the recursion limit except when the file is named 'prog.py'
# (presumably the judge's file name — NOTE(review): confirm why it is skipped there).
if __file__ != 'prog.py':
    sys.setrecursionlimit(10 ** 6)


def readints():
    """Read one line of stdin and return a lazy iterator over its ints."""
    return map(int, input().split())


def readlist():
    """Read one line of stdin and return its whitespace-separated ints as a list."""
    return list(readints())


def readstr():
    """Read one line of stdin without its trailing newline.

    Only a trailing newline character is removed, so a final line that has no
    newline keeps its last character (an unconditional ``[:-1]`` would drop it).
    """
    return input().rstrip('\n')


# Type variable for the DP value type used in the rerooting helper's annotations.
T = TypeVar('T')


class Rerooting:
    """All-roots tree DP (rerooting).

    Reference: https://null-mn.hatenablog.com/entry/2020/04/14/124151

    Suppose that, for the subtree rooted at a vertex v whose children are
    c1, ..., ck (reached by edges of cost w1, ..., wk), the DP value can be
    written as

        dp_v = g(merge(f(dp_c1, v, c1, w1), ..., f(dp_ck, v, ck, wk)), v)

    where ``merge`` is associative with identity ``e``.  Children are always
    combined in the order they appear in ``E[v]``.  Two passes (post-order,
    then pre-order) compute, for every vertex v, ``solve(v)``: the value this
    formula yields when the tree is rooted at v.

    ``E[v]`` is a list of ``(cost, neighbor)`` pairs.  After construction,
    ``self.dp[v][i]`` holds the DP value of neighbor ``E[v][i][1]``'s side
    of the tree, rooted at that neighbor, with v removed.  If ``E`` lists each
    edge in only one direction (parent -> child) and ``root`` reaches every
    vertex, no vertex ever sees its parent, so ``solve(v)`` is just the DP of
    v's own subtree.
    """

    # Annotations are quoted so the class does not need the module-level
    # TypeVar ``T`` (or the typing names) at class-definition time.
    def __init__(self, N: int, E: "List[List[Tuple[int, int]]]",
                 f: "Callable[[T, int, int, int], T]",
                 g: "Callable[[T, int], T]",
                 merge: "Callable[[T, T], T]", e: "T",
                 root: int = 0):
        self.N = N
        self.E = E
        self.f = f
        self.g = g
        self.merge = merge
        self.e = e

        # dp[v][i]: DP value across edge E[v][i]; filled by _dfs1 and _dfs2.
        self.dp = [[e] * len(E[v]) for v in range(N)]
        self._calculate(root)

    def _dfs1(self, root):
        """Post-order pass: fill dp[v][i] for every edge from v to a child."""
        E, dp, f, g, merge, e = self.E, self.dp, self.f, self.g, self.merge, self.e
        ret = [e] * self.N  # ret[v]: DP value of v's subtree (tree rooted at root)
        stack = [(root, -1)]
        while stack:
            v, p = stack.pop()
            if v < 0:
                # Second visit, encoded as ~v: every child is already finished.
                v = ~v
                dpv = dp[v]
                acc = e
                for i, (c, d) in enumerate(E[v]):
                    if d == p:
                        continue
                    dpv[i] = ret[d]
                    acc = merge(acc, f(ret[d], v, d, c))
                ret[v] = g(acc, v)
                continue

            # First visit: schedule the post-order visit, then the children.
            stack.append((~v, p))
            for c, d in E[v]:
                if d != p:
                    stack.append((d, v))

    def _dfs2(self, root):
        """Pre-order pass: fill dp[v][i] for the edge from v to its parent."""
        E, dp, f, g, merge, e = self.E, self.dp, self.f, self.g, self.merge, self.e
        stack = [(root, -1, e)]
        while stack:
            v, p, from_par = stack.pop()
            edges = E[v]
            dpv = dp[v]
            for i, (c, d) in enumerate(edges):
                if d == p:
                    dpv[i] = from_par
                    break

            # Evaluate f once per edge; the suffix and prefix folds share it.
            fv = [f(dpv[i], v, d, c) for i, (c, d) in enumerate(edges)]
            ch = len(edges)
            # suffix[i] = fv[i] (+) fv[i+1] (+) ... in adjacency order, so a
            # non-commutative merge sees the same order as in _dfs1 / solve.
            suffix = [e] * (ch + 1)
            for i in range(ch - 1, -1, -1):
                suffix[i] = merge(fv[i], suffix[i + 1])

            prefix = e  # fv[0] (+) ... (+) fv[i-1]
            for i, (c, d) in enumerate(edges):
                if d != p:
                    # Value of v's side as seen from child d: all edges but i.
                    stack.append((d, v, g(merge(prefix, suffix[i + 1]), v)))
                prefix = merge(prefix, fv[i])

    def _calculate(self, root):
        """Run the post-order pass, then the pre-order (rerooting) pass, from root."""
        self._dfs1(root)
        self._dfs2(root)

    def solve(self, v):
        """Return the DP value of the whole tree rooted at vertex v."""
        f, merge = self.f, self.merge
        ans = self.e
        for (c, d), val in zip(self.E[v], self.dp[v]):
            ans = merge(ans, f(val, v, d, c))
        return self.g(ans, v)


# Read the whole input in one call: a readline() + split() per edge is a large
# constant factor for big N.  N is taken from the same buffer because mixing
# text-mode readline() with sys.stdin.buffer.read() can lose bytes already
# pulled into the text layer's buffer.
_tokens = sys.stdin.buffer.read().split()
N = int(_tokens[0])
E = [[] for _ in range(N)]  # E[a] = [(cost, b), ...] for each directed edge a -> b
D = [0] * N  # D[v] = in-degree of v
for j in range(1, 2 * N - 1, 2):
    a = int(_tokens[j]) - 1
    b = int(_tokens[j + 1]) - 1
    E[a].append((1, b))  # every edge has length 1
    D[b] += 1

# The root is the vertex without an incoming edge.  list.index raises
# ValueError on malformed input instead of leaving `root` undefined.
root = D.index(0)

mod = 10 ** 9 + 7


def f(a, v, ch, cost):
    """Lift child ``ch``'s DP value across the edge v -> ch of length ``cost``.

    ``a = (dist_sum, size)`` for ch's subtree.  Each of its ``size`` vertices
    becomes ``cost`` farther from v, so the distance sum grows by
    ``size * cost`` (kept reduced modulo ``mod``); the size is unchanged.
    """
    return ((a[0] + a[1] * cost) % mod, a[1])


def g(a, v):
    """Finish vertex v's subtree value: count v in the size; v adds 0 to the distance sum."""
    dist_sum, size = a[0], a[1]
    return dist_sum, size + 1


def merge(a, b):
    """Combine two (distance sum, size) partial results component-wise."""
    dist_total = (a[0] + b[0]) % mod  # distance sums stay reduced modulo `mod`
    return dist_total, a[1] + b[1]


# Answer = sum, over every ancestor/descendant pair (u, v), of dist(u, v);
# all edges have length 1 in this problem.  A vertex at depth k (the root has
# depth 0) has exactly one ancestor at each distance 1..k, contributing
# k*(k+1)/2.  Summing that over all vertices needs one O(N) traversal,
# instead of a rerooting DP with per-edge tables, which exceeded the limit.
depth = [0] * N
ans = 0
stack = [root]
while stack:
    v = stack.pop()
    k = depth[v]
    ans += k * (k + 1) // 2
    for _, child in E[v]:
        depth[child] = k + 1
        stack.append(child)
print(ans % mod)