結果

問題 No.2504 NOT Path Painting
ユーザー suisensuisen
提出日時 2023-07-22 17:21:43
言語 PyPy3
(7.3.15)
結果
AC  
実行時間 680 ms / 2,000 ms
コード長 1,982 bytes
コンパイル時間 398 ms
コンパイル使用メモリ 81,352 KB
実行使用メモリ 231,212 KB
最終ジャッジ日時 2023-10-23 23:50:44
合計ジャッジ時間 10,458 ms
ジャッジサーバーID
(参考情報)
judge12 / judge11
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 61 ms
67,928 KB
testcase_01 AC 332 ms
82,172 KB
testcase_02 AC 286 ms
82,420 KB
testcase_03 AC 337 ms
82,228 KB
testcase_04 AC 339 ms
81,804 KB
testcase_05 AC 333 ms
82,684 KB
testcase_06 AC 334 ms
82,104 KB
testcase_07 AC 331 ms
82,848 KB
testcase_08 AC 326 ms
82,080 KB
testcase_09 AC 326 ms
82,904 KB
testcase_10 AC 343 ms
83,612 KB
testcase_11 AC 363 ms
82,856 KB
testcase_12 AC 355 ms
83,132 KB
testcase_13 AC 257 ms
80,416 KB
testcase_14 AC 248 ms
81,344 KB
testcase_15 AC 341 ms
108,112 KB
testcase_16 AC 357 ms
112,944 KB
testcase_17 AC 334 ms
111,260 KB
testcase_18 AC 265 ms
115,056 KB
testcase_19 AC 680 ms
231,212 KB
testcase_20 AC 530 ms
196,380 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
from typing import List

def solve(n: int, g: List[List[int]]):
    m = n * (n + 1) // 2

    denoms = []

    def dfs(u: int, p: int) -> int:
        # # of paths including {u}
        p_u = m

        # size of subtree u
        sub_u = 1
        for v in g[u]:
            if v == p:
                continue
            
            # size of subtree v
            sub_v = dfs(v, u)

            p_u -= sub_v * (sub_v + 1) // 2

            # # of paths including {u, v}
            p_uv = sub_v * (n - sub_v)
            denoms.append(-(m - p_uv)) # ans -= m / (m - p_uv)

            sub_u += sub_v
        
        if p != -1:
            sub_p = n - sub_u
            p_u -= sub_p * (sub_p + 1) // 2
        
        denoms.append(m - p_u) # ans += m / (m - p_u)

        return sub_u
    
    dfs(0, -1)

    inv_denoms = modinvs(denoms)

    return sum(inv_denoms) % P * m % P

if __name__ == '__main__':
    sys.setrecursionlimit(100000)

    P = 998244353

    def modinv(v: int):
        return pow(v, P - 2, P)
    
    def modinvs(a: List[int]):
        prod = 1
        for v in a:
            prod = prod * v % P
        prod_inv = modinv(prod)

        n = len(a)
        rprod = [0] * (n + 1)
        rprod[n] = 1
        for i in reversed(range(n)):
            rprod[i] = rprod[i + 1] * a[i] % P
        
        inv_a = [0] * n
        lprod = 1
        for i, v in enumerate(a):
            inv_a[i] = lprod * rprod[i + 1] % P * prod_inv % P
            lprod = lprod * v % P
        return inv_a
    
    answers = []

    T = int(sys.stdin.readline().rstrip())
    for _ in range(T):
        n = int(sys.stdin.readline().rstrip())
        g = [[] for _ in range(n)]
        for _ in range(n - 1):
            u, v = map(int, sys.stdin.readline().rstrip().split())
            u -= 1
            v -= 1
            g[u].append(v)
            g[v].append(u)
        
        answers.append(solve(n, g))
    
    print('\n'.join(map(str, answers)))
0