結果

問題 No.2504 NOT Path Painting
ユーザー suisensuisen
提出日時 2023-07-22 17:41:32
言語 PyPy3
(7.3.15)
結果
AC  
実行時間 498 ms / 2,000 ms
コード長 2,183 bytes
コンパイル時間 309 ms
コンパイル使用メモリ 87,284 KB
実行使用メモリ 119,296 KB
最終ジャッジ日時 2023-10-13 18:15:01
合計ジャッジ時間 10,044 ms
ジャッジサーバーID
(参考情報)
judge13 / judge12
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 168 ms
79,976 KB
testcase_01 AC 391 ms
84,896 KB
testcase_02 AC 344 ms
84,804 KB
testcase_03 AC 355 ms
84,764 KB
testcase_04 AC 347 ms
84,528 KB
testcase_05 AC 354 ms
84,712 KB
testcase_06 AC 347 ms
85,100 KB
testcase_07 AC 355 ms
85,240 KB
testcase_08 AC 350 ms
85,064 KB
testcase_09 AC 341 ms
84,856 KB
testcase_10 AC 346 ms
85,048 KB
testcase_11 AC 348 ms
84,892 KB
testcase_12 AC 407 ms
86,004 KB
testcase_13 AC 322 ms
83,048 KB
testcase_14 AC 332 ms
84,344 KB
testcase_15 AC 452 ms
111,872 KB
testcase_16 AC 471 ms
116,484 KB
testcase_17 AC 472 ms
113,604 KB
testcase_18 AC 396 ms
119,296 KB
testcase_19 AC 498 ms
115,544 KB
testcase_20 AC 398 ms
115,100 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
from typing import List

def bottom_up_order(g: List[List[int]], root: int = 0):
    n = len(g)
    parent = [-1] * n
    bfs_order = [root]
    for u in bfs_order:
        for v in g[u]:
            if v != parent[u]:
                bfs_order.append(v)
                parent[v] = u
    return list(reversed(bfs_order)), parent

def solve(n: int, g: List[List[int]]):
    m = n * (n + 1) // 2

    denoms = []

    vs, par = bottom_up_order(g)

    # size of subtrees
    sub = [0] * n

    for u in vs:
        # # of paths including {u}
        p_u = m

        sub[u] = 1
        for v in g[u]:
            if v == par[u]:
                continue

            p_u -= sub[v] * (sub[v] + 1) // 2

            # # of paths including {u, v}
            p_uv = sub[v] * (n - sub[v])
            denoms.append(-(m - p_uv)) # ans -= m / (m - p_uv)

            sub[u] += sub[v]
        
        if par[u] != -1:
            sub_p = n - sub[u]
            p_u -= sub_p * (sub_p + 1) // 2
        
        denoms.append(m - p_u) # ans += m / (m - p_u)

    inv_denoms = modinvs(denoms)

    return sum(inv_denoms) % P * m % P

if __name__ == '__main__':
    P = 998244353

    def modinv(v: int):
        return pow(v, P - 2, P)
    
    def modinvs(a: List[int]):
        prod = 1
        for v in a:
            prod = prod * v % P
        prod_inv = modinv(prod)

        n = len(a)
        rprod = [0] * (n + 1)
        rprod[n] = 1
        for i in reversed(range(n)):
            rprod[i] = rprod[i + 1] * a[i] % P
        
        inv_a = [0] * n
        lprod = 1
        for i, v in enumerate(a):
            inv_a[i] = lprod * rprod[i + 1] % P * prod_inv % P
            lprod = lprod * v % P
        return inv_a
    
    answers = []

    T = int(sys.stdin.readline().rstrip())
    for _ in range(T):
        n = int(sys.stdin.readline().rstrip())
        g = [[] for _ in range(n)]
        for _ in range(n - 1):
            u, v = map(int, sys.stdin.readline().rstrip().split())
            u -= 1
            v -= 1
            g[u].append(v)
            g[v].append(u)
        
        answers.append(solve(n, g))
    
    print('\n'.join(map(str, answers)))
0