結果

問題 No.2504 NOT Path Painting
ユーザー suisensuisen
提出日時 2023-02-21 11:09:20
言語 PyPy3
(7.3.15)
結果
RE  
(最新)
AC  
(最初)
実行時間 -
コード長 4,000 bytes
コンパイル時間 1,312 ms
コンパイル使用メモリ 81,392 KB
実行使用メモリ 70,560 KB
最終ジャッジ日時 2023-10-23 23:29:14
合計ジャッジ時間 3,018 ms
ジャッジサーバーID
(参考情報)
judge14 / judge11
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 RE -
testcase_01 RE -
testcase_02 RE -
testcase_03 RE -
testcase_04 RE -
testcase_05 RE -
testcase_06 RE -
testcase_07 RE -
testcase_08 RE -
testcase_09 RE -
testcase_10 RE -
testcase_11 RE -
testcase_12 RE -
testcase_13 RE -
testcase_14 RE -
testcase_15 RE -
testcase_16 RE -
testcase_17 RE -
testcase_18 RE -
testcase_19 RE -
testcase_20 RE -
権限があれば一括ダウンロードができます

ソースコード

diff #

from collections import deque
from typing import List

P = 998244353

def inv(n):
    return pow(n, P - 2, P)

def edge_num(n: int):
    return (n * (n + 1)) >> 1

def solve(n: int, g: List[List[int]]):
    m = edge_num(n)

    inv_m = inv(m)

    par_ = [0] * n
    siz_ = [1] * n
    def precalc(u: int, p: int):
        par_[u] = p
        for v in g[u]:
            if v != p:
                siz_[u] += precalc(v, u)
        return siz_[u]

    precalc(0, -1)

    # u の親を p としたときの、部分木 u のサイズ
    def subtree_size(u: int, p: int):
        if par_[u] == p:
            return siz_[u]
        else:
            return n - siz_[p]

    # 解説の t (隣接点が ng1 の場合)
    def calc_t_1(u: int, ng1: int):
        return n - subtree_size(ng1, u)

    # 解説の t (隣接点が ng1, ng2 の場合)
    def calc_t_2(u: int, ng1: int, ng2: int):
        return n - subtree_size(ng1, u) - subtree_size(ng2, u)

    ans_f = [0] * n

    for x in range(n):
        # u_{x,x}(x)
        u_xx_x = m - sum(edge_num(subtree_size(y, x)) for y in g[x])
        ans_f[x] = m * inv(m - u_xx_x) % P

    ans_g = [[0] * n for _ in range(n)]

    # par[x][y] := x を根とする木における y の親
    par = [[-1] * n for _ in range(n)]

    # x, y, A_{x,y}, B_{x,y}
    dq = deque()
    for x in range(n):
        ans_g[x][x] = ans_f[x]

        # s_x(x)
        s_x_x = n
        # u_{x,x}(x)
        u_xx_x = edge_num(n) - sum(edge_num(subtree_size(y, x)) for y in g[x])
        for y in g[x]:
            # s_x(y)
            s_x_y = subtree_size(y, x)
            # u_{x,y}(x)
            u_xy_x = u_xx_x - s_x_y * (s_x_x - s_x_y)
            Axy = u_xy_x * ans_f[x] % P
            Bxy = 0

            par[x][y] = x
            dq.append((x, y, Axy, Bxy))

    while dq:
        x, y, Axy, Bxy = dq.popleft()

        # x を根とした木における y の親
        par_y = par[x][y]
        # s_x(y)
        s_x_y = subtree_size(y, par_y)

        # u_{x,y}(y)
        u_xy_y = edge_num(s_x_y) - sum(edge_num(subtree_size(w, y)) for w in g[y] if w != par_y)
        # t_{x,y}(y)
        t_xy_y = s_x_y
        
        ans_g[x][y] = (Axy + u_xy_y * ans_f[y] + Bxy) % P
        prev_z, z = y, par_y
        while z != x:
            next_z = par[x][z]
            # t_{x,y}(z)
            # z の 1 つ前と 1 つ後が N_{x,y}(z) に含まれる頂点
            t_xy_z = calc_t_2(z, prev_z, next_z)
            ans_g[x][y] = (ans_g[x][y] + t_xy_y * t_xy_z * ans_g[y][z]) % P
            prev_z, z = z, next_z
        # t_{x,y}(x)
        t_xy_x = calc_t_1(x, prev_z)
        ans_g[x][y] = ((1 + ans_g[x][y] * inv_m) % P * inv(1 - (t_xy_x * t_xy_y * inv_m))) % P

        for w in g[y]:
            if w == par_y:
                continue

            # t_{x,w}(x)
            t_xw_x = t_xy_x
            # s_{x}(w)
            s_x_w = subtree_size(w, y)
            # t_{x,w}(y)
            t_xw_y = t_xy_y - s_x_w
            # u_{x,w}(y)
            u_xw_y = u_xy_y - s_x_w * (s_x_y - s_x_w)

            # A_{x,w}
            Axw = (Axy + u_xw_y * ans_f[y]) % P
            # B_{x,w}
            Bxw = (Bxy + t_xw_y * t_xw_x * ans_g[x][y]) % P

            # Bxw に sum_{z in Pxy-{y}} t_{x,w}(y) * t_{x,w}(z) * g(y,z) を足して更新
            prev_z, z = y, par_y
            while z != x:
                next_z = par[x][z]
                # t_{x,w}(z)
                t_xw_z = calc_t_2(z, prev_z, next_z)
                # t_{x,w}(y) * t_{x,w}(z) * g(y, z)
                Bxw = (Bxw + t_xw_y * t_xw_z * ans_g[y][z]) % P
                prev_z, z = z, next_z

            par[x][w] = y
            dq.append((x, w, Axw, Bxw))

    ans = 1
    for x in range(n):
        for y in range(x + 1):
            ans = (ans + ans_g[x][y] * inv_m) % P
    print(ans)

n = int(input())
g = [[] for _ in range(n)]

for _ in range(n - 1):
    u, v = map(int, input().split())
    u -= 1
    v -= 1
    g[u].append(v)
    g[v].append(u)

solve(n, g)
0