結果

問題 No.2504 NOT Path Painting
ユーザー suisensuisen
提出日時 2023-07-22 17:41:32
言語 PyPy3
(7.3.15)
結果
AC  
実行時間 360 ms / 2,000 ms
コード長 2,183 bytes
コンパイル時間 310 ms
コンパイル使用メモリ 81,904 KB
実行使用メモリ 116,180 KB
最終ジャッジ日時 2024-09-15 14:04:39
合計ジャッジ時間 7,079 ms
ジャッジサーバーID
(参考情報)
judge6 / judge2
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 66 ms
67,456 KB
testcase_01 AC 257 ms
80,876 KB
testcase_02 AC 244 ms
80,748 KB
testcase_03 AC 247 ms
80,640 KB
testcase_04 AC 243 ms
80,616 KB
testcase_05 AC 246 ms
80,620 KB
testcase_06 AC 247 ms
81,184 KB
testcase_07 AC 248 ms
80,748 KB
testcase_08 AC 247 ms
80,748 KB
testcase_09 AC 242 ms
80,872 KB
testcase_10 AC 243 ms
80,964 KB
testcase_11 AC 243 ms
81,124 KB
testcase_12 AC 300 ms
81,672 KB
testcase_13 AC 216 ms
79,208 KB
testcase_14 AC 226 ms
80,324 KB
testcase_15 AC 338 ms
111,760 KB
testcase_16 AC 354 ms
116,064 KB
testcase_17 AC 338 ms
113,724 KB
testcase_18 AC 281 ms
116,100 KB
testcase_19 AC 360 ms
116,180 KB
testcase_20 AC 291 ms
115,108 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
from typing import List

def bottom_up_order(g: List[List[int]], root: int = 0):
    n = len(g)
    parent = [-1] * n
    bfs_order = [root]
    for u in bfs_order:
        for v in g[u]:
            if v != parent[u]:
                bfs_order.append(v)
                parent[v] = u
    return list(reversed(bfs_order)), parent

def solve(n: int, g: List[List[int]]):
    m = n * (n + 1) // 2

    denoms = []

    vs, par = bottom_up_order(g)

    # size of subtrees
    sub = [0] * n

    for u in vs:
        # # of paths including {u}
        p_u = m

        sub[u] = 1
        for v in g[u]:
            if v == par[u]:
                continue

            p_u -= sub[v] * (sub[v] + 1) // 2

            # # of paths including {u, v}
            p_uv = sub[v] * (n - sub[v])
            denoms.append(-(m - p_uv)) # ans -= m / (m - p_uv)

            sub[u] += sub[v]
        
        if par[u] != -1:
            sub_p = n - sub[u]
            p_u -= sub_p * (sub_p + 1) // 2
        
        denoms.append(m - p_u) # ans += m / (m - p_u)

    inv_denoms = modinvs(denoms)

    return sum(inv_denoms) % P * m % P

if __name__ == '__main__':
    P = 998244353

    def modinv(v: int):
        return pow(v, P - 2, P)
    
    def modinvs(a: List[int]):
        prod = 1
        for v in a:
            prod = prod * v % P
        prod_inv = modinv(prod)

        n = len(a)
        rprod = [0] * (n + 1)
        rprod[n] = 1
        for i in reversed(range(n)):
            rprod[i] = rprod[i + 1] * a[i] % P
        
        inv_a = [0] * n
        lprod = 1
        for i, v in enumerate(a):
            inv_a[i] = lprod * rprod[i + 1] % P * prod_inv % P
            lprod = lprod * v % P
        return inv_a
    
    answers = []

    T = int(sys.stdin.readline().rstrip())
    for _ in range(T):
        n = int(sys.stdin.readline().rstrip())
        g = [[] for _ in range(n)]
        for _ in range(n - 1):
            u, v = map(int, sys.stdin.readline().rstrip().split())
            u -= 1
            v -= 1
            g[u].append(v)
            g[v].append(u)
        
        answers.append(solve(n, g))
    
    print('\n'.join(map(str, answers)))
0