結果

問題 No.2115 Making Forest Easy
ユーザー 👑 rin204rin204
提出日時 2022-10-28 23:11:07
言語 PyPy3
(7.3.15)
結果
TLE  
実行時間 -
コード長 2,468 bytes
コンパイル時間 256 ms
コンパイル使用メモリ 82,000 KB
実行使用メモリ 173,192 KB
最終ジャッジ日時 2024-07-06 02:24:22
合計ジャッジ時間 6,163 ms
ジャッジサーバーID
(参考情報)
judge1 / judge3
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 34 ms
61,280 KB
testcase_01 AC 38 ms
52,868 KB
testcase_02 TLE -
testcase_03 TLE -
testcase_04 -- -
testcase_05 -- -
testcase_06 -- -
testcase_07 -- -
testcase_08 -- -
testcase_09 -- -
testcase_10 -- -
testcase_11 -- -
testcase_12 -- -
testcase_13 -- -
testcase_14 -- -
testcase_15 -- -
testcase_16 -- -
testcase_17 -- -
testcase_18 -- -
testcase_19 -- -
testcase_20 -- -
testcase_21 -- -
testcase_22 -- -
testcase_23 -- -
testcase_24 -- -
testcase_25 -- -
testcase_26 -- -
testcase_27 -- -
testcase_28 -- -
testcase_29 -- -
testcase_30 -- -
testcase_31 -- -
testcase_32 -- -
testcase_33 -- -
testcase_34 -- -
testcase_35 -- -
testcase_36 -- -
testcase_37 -- -
testcase_38 -- -
testcase_39 -- -
testcase_40 -- -
testcase_41 -- -
testcase_42 -- -
testcase_43 -- -
testcase_44 -- -
testcase_45 -- -
testcase_46 -- -
testcase_47 -- -
testcase_48 -- -
testcase_49 -- -
testcase_50 -- -
testcase_51 -- -
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
sys.setrecursionlimit(10 ** 9)
import pypyjit
pypyjit.set_param('max_unroll_recursion=-1')

MOD = 998244353

n = int(input())
A = list(map(int, input().split()))
edges = [[] for _ in range(n)]
for _ in range(n - 1):
    u, v = map(int, input().split())
    u -= 1
    v -= 1
    edges[u].append(v)
    edges[v].append(u)

B = [(a, i) for i, a in enumerate(A)]
B.sort(reverse = True)
used = [False] * n
pow2 = [1]
for _ in range(n):
    pow2.append(pow2[-1] * 2 % MOD)

ans = 0
for a, i in B:
    dist = [-1] * n
    size = [0] * n
    dist[i] = 0
    stack = [~i, i]
    while stack:
        pos = stack.pop()
        if pos >= 0:
            for npos in edges[pos]:
                if dist[npos] == -1:
                    dist[npos] = dist[pos] + 1
                    stack.append(~npos)
                    stack.append(npos)
        else:
            pos = ~pos
            size[pos] += 1
            for npos in edges[pos]:
                size[pos] += size[npos]
    
    dp = [0] * n
    def dfs(pos, bpos):
        ret = 1
        for npos in edges[pos]:
            if npos == bpos:
                continue
            if used[npos]:
                ret *= pow2[size[npos] - 1]
            else:
                dfs(npos, pos)
                ret *= dp[npos] + pow2[size[npos] - 1]
            ret %= MOD
        dp[pos] = ret

    dfs(i, -1)
    
    tot = 0
    def dfs2(pos, bpos):
        global tot
        
        L = [1]
        for npos in edges[pos]:
            if used[npos]:
                L.append(L[-1] * pow2[size[npos] - 1] % MOD)
            elif npos == bpos:
                L.append(L[-1] * dp[npos] % MOD)
            else:
                L.append(L[-1] * (dp[npos] + pow2[size[npos] - 1]) % MOD)

        tot += L[-1]
        tot %= MOD
            
        R = [1]
        for npos in edges[pos][::-1]:
            if used[npos]:
                R.append(R[-1] * pow2[size[npos] - 1] % MOD)
            elif npos == bpos:
                R.append(R[-1] * dp[npos] % MOD)
            else:
                R.append(R[-1] * (dp[npos] + pow2[size[npos] - 1]) % MOD)
        R = R[::-1]
        
        for ii, npos in enumerate(edges[pos]):
            if npos == bpos or used[npos]:
                continue
            dp[pos] = L[ii] * R[ii + 1] % MOD
            size[pos] = n - size[npos]
            dfs2(npos, pos)

    dfs2(i, -1)
    ans += tot * a % MOD
    
    ans %= MOD
    used[i] = True
    

print(ans)


0