結果
| 問題 |
No.2115 Making Forest Easy
|
| コンテスト | |
| ユーザー |
|
| 提出日時 | 2022-10-28 23:11:07 |
| 言語 | PyPy3 (7.3.15) |
| 結果 |
TLE
|
| 実行時間 | - |
| コード長 | 2,468 bytes |
| コンパイル時間 | 256 ms |
| コンパイル使用メモリ | 82,000 KB |
| 実行使用メモリ | 173,192 KB |
| 最終ジャッジ日時 | 2024-07-06 02:24:22 |
| 合計ジャッジ時間 | 6,163 ms |
|
ジャッジサーバーID (参考情報) |
judge1 / judge3 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 2 |
| other | TLE * 2 -- * 48 |
ソースコード
import sys
sys.setrecursionlimit(10 ** 9)
import pypyjit
pypyjit.set_param('max_unroll_recursion=-1')
MOD = 998244353
n = int(input())
A = list(map(int, input().split()))
edges = [[] for _ in range(n)]
for _ in range(n - 1):
u, v = map(int, input().split())
u -= 1
v -= 1
edges[u].append(v)
edges[v].append(u)
B = [(a, i) for i, a in enumerate(A)]
B.sort(reverse = True)
used = [False] * n
pow2 = [1]
for _ in range(n):
pow2.append(pow2[-1] * 2 % MOD)
ans = 0
for a, i in B:
dist = [-1] * n
size = [0] * n
dist[i] = 0
stack = [~i, i]
while stack:
pos = stack.pop()
if pos >= 0:
for npos in edges[pos]:
if dist[npos] == -1:
dist[npos] = dist[pos] + 1
stack.append(~npos)
stack.append(npos)
else:
pos = ~pos
size[pos] += 1
for npos in edges[pos]:
size[pos] += size[npos]
dp = [0] * n
def dfs(pos, bpos):
ret = 1
for npos in edges[pos]:
if npos == bpos:
continue
if used[npos]:
ret *= pow2[size[npos] - 1]
else:
dfs(npos, pos)
ret *= dp[npos] + pow2[size[npos] - 1]
ret %= MOD
dp[pos] = ret
dfs(i, -1)
tot = 0
def dfs2(pos, bpos):
global tot
L = [1]
for npos in edges[pos]:
if used[npos]:
L.append(L[-1] * pow2[size[npos] - 1] % MOD)
elif npos == bpos:
L.append(L[-1] * dp[npos] % MOD)
else:
L.append(L[-1] * (dp[npos] + pow2[size[npos] - 1]) % MOD)
tot += L[-1]
tot %= MOD
R = [1]
for npos in edges[pos][::-1]:
if used[npos]:
R.append(R[-1] * pow2[size[npos] - 1] % MOD)
elif npos == bpos:
R.append(R[-1] * dp[npos] % MOD)
else:
R.append(R[-1] * (dp[npos] + pow2[size[npos] - 1]) % MOD)
R = R[::-1]
for ii, npos in enumerate(edges[pos]):
if npos == bpos or used[npos]:
continue
dp[pos] = L[ii] * R[ii + 1] % MOD
size[pos] = n - size[npos]
dfs2(npos, pos)
dfs2(i, -1)
ans += tot * a % MOD
ans %= MOD
used[i] = True
print(ans)