結果

問題 No.2504 NOT Path Painting
ユーザー suisensuisen
提出日時 2023-07-22 17:21:43
言語 PyPy3
(7.3.15)
結果
AC  
実行時間 681 ms / 2,000 ms
コード長 1,982 bytes
コンパイル時間 375 ms
コンパイル使用メモリ 82,616 KB
実行使用メモリ 231,460 KB
最終ジャッジ日時 2024-09-22 16:57:33
合計ジャッジ時間 8,946 ms
ジャッジサーバーID
(参考情報)
judge4 / judge3
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 62 ms
68,904 KB
testcase_01 AC 342 ms
83,000 KB
testcase_02 AC 298 ms
82,932 KB
testcase_03 AC 341 ms
82,856 KB
testcase_04 AC 348 ms
82,288 KB
testcase_05 AC 337 ms
83,532 KB
testcase_06 AC 337 ms
82,780 KB
testcase_07 AC 344 ms
83,668 KB
testcase_08 AC 346 ms
82,752 KB
testcase_09 AC 343 ms
83,388 KB
testcase_10 AC 352 ms
84,284 KB
testcase_11 AC 351 ms
83,256 KB
testcase_12 AC 342 ms
83,768 KB
testcase_13 AC 265 ms
81,044 KB
testcase_14 AC 257 ms
82,264 KB
testcase_15 AC 378 ms
109,496 KB
testcase_16 AC 386 ms
114,028 KB
testcase_17 AC 380 ms
112,088 KB
testcase_18 AC 288 ms
115,604 KB
testcase_19 AC 681 ms
231,460 KB
testcase_20 AC 551 ms
195,680 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
from typing import List

def solve(n: int, g: List[List[int]]):
    m = n * (n + 1) // 2

    denoms = []

    def dfs(u: int, p: int) -> int:
        # # of paths including {u}
        p_u = m

        # size of subtree u
        sub_u = 1
        for v in g[u]:
            if v == p:
                continue
            
            # size of subtree v
            sub_v = dfs(v, u)

            p_u -= sub_v * (sub_v + 1) // 2

            # # of paths including {u, v}
            p_uv = sub_v * (n - sub_v)
            denoms.append(-(m - p_uv)) # ans -= m / (m - p_uv)

            sub_u += sub_v
        
        if p != -1:
            sub_p = n - sub_u
            p_u -= sub_p * (sub_p + 1) // 2
        
        denoms.append(m - p_u) # ans += m / (m - p_u)

        return sub_u
    
    dfs(0, -1)

    inv_denoms = modinvs(denoms)

    return sum(inv_denoms) % P * m % P

if __name__ == '__main__':
    sys.setrecursionlimit(100000)

    P = 998244353

    def modinv(v: int):
        return pow(v, P - 2, P)
    
    def modinvs(a: List[int]):
        prod = 1
        for v in a:
            prod = prod * v % P
        prod_inv = modinv(prod)

        n = len(a)
        rprod = [0] * (n + 1)
        rprod[n] = 1
        for i in reversed(range(n)):
            rprod[i] = rprod[i + 1] * a[i] % P
        
        inv_a = [0] * n
        lprod = 1
        for i, v in enumerate(a):
            inv_a[i] = lprod * rprod[i + 1] % P * prod_inv % P
            lprod = lprod * v % P
        return inv_a
    
    answers = []

    T = int(sys.stdin.readline().rstrip())
    for _ in range(T):
        n = int(sys.stdin.readline().rstrip())
        g = [[] for _ in range(n)]
        for _ in range(n - 1):
            u, v = map(int, sys.stdin.readline().rstrip().split())
            u -= 1
            v -= 1
            g[u].append(v)
            g[v].append(u)
        
        answers.append(solve(n, g))
    
    print('\n'.join(map(str, answers)))
0