結果
| 問題 |
No.2949 Product on Tree
|
| コンテスト | |
| ユーザー |
|
| 提出日時 | 2024-10-26 00:30:41 |
| 言語 | PyPy3 (7.3.15) |
| 結果 |
AC
|
| 実行時間 | 1,030 ms / 2,000 ms |
| コード長 | 2,451 bytes |
| コンパイル時間 | 346 ms |
| コンパイル使用メモリ | 82,432 KB |
| 実行使用メモリ | 132,116 KB |
| 最終ジャッジ日時 | 2024-10-26 00:31:25 |
| 合計ジャッジ時間 | 41,471 ms |
|
ジャッジサーバーID (参考情報) |
judge2 / judge3 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 3 |
| other | AC * 46 |
ソースコード
import sys, time, random
from collections import deque, Counter, defaultdict
def debug(*x):print('debug:',*x, file=sys.stderr)
input = lambda: sys.stdin.readline().rstrip()
ii = lambda: int(input())
mi = lambda: map(int, input().split())
li = lambda: list(mi())
inf = 2 ** 61 - 1
mod = 998244353
from collections import deque
def TreeDepth(s, graph):
inf = 2 ** 61 - 1
n = len(graph)
depth = [inf] * n
depth[s] = 0
q = deque()
q.append(s)
while q:
now = q.popleft()
for to in graph[now]:
if depth[to] == inf:
depth[to] = depth[now] + 1
q.append(to)
return depth
def TreeOrder(s, graph):
dist = TreeDepth(s, graph)
n = len(graph)
l = list(range(n))
l.sort(key=lambda x: dist[x])
return l
def subTree(s, graph):
l = TreeOrder(s, graph)
n = len(graph)
sub = [0] * n
for v in l[::-1]:
sub[v] = 1
for to in graph[v]:
sub[v] += sub[to]
return sub
def Treeheight(s, graph):
l = TreeOrder(s, graph)
n = len(graph)
height = [0] * n
for v in l[::-1]:
height[v] = max([height[to] for to in graph[v]] + [0]) + 1
return height
def EulerTour(s, graph):
n = len(graph)
done = [0] * n
Q = [~s, s] # 根をスタックに追加
ET = []
while Q:
i = Q.pop()
if i >= 0: # 行きがけの処理
done[i] = 1
ET.append(i)
for a in graph[i][::-1]:
if done[a]: continue
Q.append(~a) # 帰りがけの処理をスタックに追加
Q.append(a) # 行きがけの処理をスタックに追加
else: # 帰りがけの処理
ET.append(~i)
return ET
n = ii()
a = li()
graph = [[] for _ in range(n)]
for _ in range(n - 1):
u, v = mi()
u -= 1
v -= 1
graph[u].append(v)
graph[v].append(u)
L = TreeOrder(0, graph)
d = TreeDepth(0, graph)
dp = [1] * n
ans = 0
i2 = pow(2, -1, mod)
for v in L[::-1]:
s = 0
nans = 0
dp[v] = a[v]
f = 0
for to in graph[v]:
if d[to] > d[v]:
f = 1
dp[v] += a[v] * dp[to]
nans -= a[v] * dp[to] % mod * dp[to]
dp[v] %= mod
nans %= mod
s += dp[to]
nans += a[v] * s * s
nans *= i2
nans %= mod
ans += dp[v] + nans - a[v]
ans %= mod
if not f:
dp[v] = a[v]
print(ans)