結果

問題 No.1103 Directed Length Sum
ユーザー terasaterasa
提出日時 2022-11-20 12:55:19
言語 PyPy3 (7.3.15)
結果
TLE  
実行時間 -
コード長 3,620 bytes
コンパイル時間 249 ms
コンパイル使用メモリ 81,692 KB
実行使用メモリ 410,468 KB
最終ジャッジ日時 2024-09-21 13:27:11
合計ジャッジ時間 14,684 ms
ジャッジサーバーID (参考情報) judge3 / judge2
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間 実行使用メモリ
testcase_00 AC 59 ms 69,368 KB
testcase_01 AC 60 ms 68,480 KB
testcase_02 AC 1,557 ms 396,884 KB
testcase_03 AC 1,323 ms 410,468 KB
testcase_04 AC 2,030 ms 195,224 KB
testcase_05 TLE -
testcase_06 AC 1,449 ms 155,596 KB
testcase_07 AC 758 ms 98,200 KB
testcase_08 AC 508 ms 108,128 KB
testcase_09 AC 253 ms 92,064 KB
testcase_10 AC 625 ms 118,268 KB
testcase_11 AC 2,178 ms 205,672 KB
testcase_12 AC 1,329 ms 155,840 KB
testcase_13 AC 708 ms 119,140 KB
testcase_14 AC 223 ms 88,420 KB
testcase_15 AC 1,012 ms 139,460 KB
testcase_16 AC 2,564 ms 221,580 KB
testcase_17 AC 2,642 ms 227,036 KB
testcase_18 AC 651 ms 117,624 KB
testcase_19 AC 2,283 ms 208,432 KB
testcase_20 AC 303 ms 94,472 KB
testcase_21 AC 458 ms 105,248 KB
testcase_22 AC 1,879 ms 185,140 KB
testcase_23 AC 1,093 ms 143,652 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

from typing import List, Tuple, Callable, TypeVar
from typing import List, Tuple, Optional
import sys
import itertools
import heapq
import bisect
import math
from collections import deque, defaultdict
from functools import lru_cache, cmp_to_key

# Shadow the builtin: every input() call below returns one raw stdin line,
# trailing newline included (callers split() or strip it themselves).
input = sys.stdin.readline

# Raise the recursion limit unless __file__ is exactly 'prog.py'.
# NOTE(review): presumably 'prog.py' is the judge's file name — confirm why
# the limit is skipped there; the traversals in this file are iterative.
if __file__ != 'prog.py':
    sys.setrecursionlimit(1_000_000)


def readints():
    """Read the next stdin line and return a lazy ``map`` of its int fields."""
    fields = input().split()
    return map(int, fields)
def readlist():
    """Return the integers on the next stdin line as a list."""
    return [*readints()]
def readstr():
    """Return the next stdin line without its trailing newline.

    Strips only a newline that is actually present. The previous ``[:-1]``
    slice unconditionally dropped the last character, which silently lost
    real data when the final line of input had no terminating newline.
    """
    return input().rstrip('\n')


# Generic value type used in the Rerooting callback annotations.
# (A second `input = sys.stdin.readline` that used to follow was dropped:
# `input` is already bound to the same function earlier in the file.)
T = TypeVar('T')


class Rerooting:
    """Tree DP evaluated for every possible root ("rerooting").

    Reference: https://null-mn.hatenablog.com/entry/2020/04/14/124151

    For an arbitrary root, the value dp_v of the subtree hanging below v must
    be expressible through v's children c1, ..., ck as

        dp_v = g(merge(f(dp_c1, v, c1, cost1), ..., f(dp_ck, v, ck, costk)), v)

    with ``e`` the identity of ``merge``.  Contributions are merged both
    left-to-right and right-to-left over the same edge list, so ``merge`` is
    assumed to be associative and commutative.

    Parameters:
        N: number of vertices (labels 0 .. N-1).
        E: adjacency list; ``E[v]`` is a list of ``(cost, neighbour)`` pairs.
        f: ``f(dp_child, v, child, cost)`` -> contribution of one edge to v.
        g: ``g(merged, v)`` -> dp value at v.
        merge: combines two contributions.
        e: identity element of ``merge``.
        root: vertex where both traversals start.

    After construction, ``self.dp[v][i]`` holds the dp value of the part of
    the tree on the far side of edge ``E[v][i]`` as seen from v, and
    ``solve(v)`` returns dp_v for the tree rooted at v.
    """

    def __init__(self, N: int, E: List[Tuple[int, int]],
                 f: Callable[[T, int, int, int], T],
                 g: Callable[[T, int], T],
                 merge: Callable[[T, T], T], e: T,
                 root: int = 0):
        self.N, self.E = N, E
        self.f, self.g = f, g
        self.merge, self.e = merge, e

        # One slot per adjacency entry, filled by the two passes below.
        self.dp = [[e] * len(E[v]) for v in range(N)]
        self._calculate(root)

    def _dfs1(self, root):
        """Bottom-up pass: for each vertex, record every child's subtree value.

        Iterative post-order: a vertex is pushed once as ``v`` (expand its
        children) and once as ``~v`` (all children done, combine them).
        """
        adj, dp, unit = self.E, self.dp, self.e
        f, g, merge = self.f, self.g, self.merge
        sub = [unit] * self.N
        todo = [(root, -1)]
        while todo:
            node, parent = todo.pop()
            if node >= 0:
                # First visit: schedule the combine step, then the children.
                todo.append((~node, parent))
                todo.extend((nxt, node) for _, nxt in adj[node] if nxt != parent)
                continue

            node = ~node
            total = unit
            for idx, (cost, nxt) in enumerate(adj[node]):
                if nxt == parent:
                    continue
                dp[node][idx] = sub[nxt]
                total = merge(total, f(sub[nxt], node, nxt, cost))
            sub[node] = g(total, node)

    def _dfs2(self, root):
        """Top-down pass: fill in each vertex's slot for the edge to its parent.

        A vertex receives ``up_value`` (the dp value of everything above it),
        stores it on its parent edge, then hands each child the combination of
        all its *other* edges, built from prefix and suffix merges instead of
        re-merging every edge once per child.
        """
        adj, dp, unit = self.E, self.dp, self.e
        f, g, merge = self.f, self.g, self.merge
        todo = [(root, -1, unit)]
        while todo:
            node, parent, up_value = todo.pop()
            edges, row = adj[node], dp[node]
            parent_idx = next(
                (k for k, (_, nxt) in enumerate(edges) if nxt == parent), None)
            if parent_idx is not None:
                row[parent_idx] = up_value

            deg = len(edges)
            # suffix[k]: merged contributions of edges k .. deg-1.
            suffix = [unit] * (deg + 1)
            for k in reversed(range(deg)):
                cost, nxt = edges[k]
                suffix[k] = merge(suffix[k + 1], f(row[k], node, nxt, cost))

            prefix = unit  # merged contributions of edges 0 .. k-1
            for k, (cost, nxt) in enumerate(edges):
                if nxt != parent:
                    without_k = merge(prefix, suffix[k + 1])
                    todo.append((nxt, node, g(without_k, node)))
                prefix = merge(prefix, f(row[k], node, nxt, cost))

    def _calculate(self, root):
        """Run the bottom-up pass, then the top-down pass, from ``root``."""
        self._dfs1(root)
        self._dfs2(root)

    def solve(self, v):
        """Return dp_v for the tree rooted at ``v``."""
        acc = self.e
        for idx, (cost, nxt) in enumerate(self.E[v]):
            acc = self.merge(acc, self.f(self.dp[v][idx], v, nxt, cost))
        return self.g(acc, v)


N = int(input())
# E[a] holds (cost, b) for every directed edge a -> b (cost is always 1 here);
# D[b] counts incoming edges, so the root is a vertex with D == 0.
E = [[] for _ in range(N)]
D = [0] * N
for _ in range(N - 1):
    a, b = (x - 1 for x in readints())  # input labels are 1-based
    E[a].append((1, b))
    D[b] += 1

for i, indeg in enumerate(D):
    if indeg == 0:
        root = i
        break

mod = 1_000_000_007


def f(a, v, ch, cost):
    """Carry a child's (distance_sum, vertex_count) pair up one edge.

    Each of the ``a[1]`` vertices below the edge becomes one step farther
    away, so the distance sum grows by ``a[1]`` (reduced modulo ``mod``).
    ``v``, ``ch`` and ``cost`` are ignored; the edges built by this script
    all carry cost 1.
    """
    dist_sum, count = a[0], a[1]
    return (dist_sum + count) % mod, count


def g(a, v):
    """Close a subtree at ``v``: count ``v`` itself (distance 0, sum unchanged)."""
    dist_sum, count = a[0], a[1]
    return dist_sum, count + 1


def merge(a, b):
    """Add two (distance_sum, vertex_count) pairs component-wise.

    The distance sum is reduced modulo ``mod``; the count is kept exact.
    """
    dist_sum = (a[0] + b[0]) % mod
    count = a[1] + b[1]
    return dist_sum, count


# Answer: the sum, over every ancestor -> descendant pair, of the path length.
# A vertex at depth d has ancestors at distances 1, 2, ..., d, so it
# contributes d * (d + 1) // 2, and summing that over all vertices counts
# every pair exactly once.
#
# This single O(N) traversal replaces the generic Rerooting pass, which
# exceeded the time limit (testcase_05). Since E only contains
# parent -> child edges, that pass's second sweep never finds a parent edge,
# so its work was wasted anyway.
depth = [0] * N
stack = [root]
while stack:
    v = stack.pop()
    for _, child in E[v]:
        depth[child] = depth[v] + 1
        stack.append(child)

ans = sum(d * (d + 1) // 2 for d in depth) % mod
print(ans)