結果
| 問題 |
No.2949 Product on Tree
|
| コンテスト | |
| ユーザー |
lam6er
|
| 提出日時 | 2025-03-20 20:53:41 |
| 言語 | PyPy3 (7.3.15) |
| 結果 |
AC
|
| 実行時間 | 1,871 ms / 2,000 ms |
| コード長 | 1,470 bytes |
| コンパイル時間 | 862 ms |
| コンパイル使用メモリ | 82,952 KB |
| 実行使用メモリ | 396,404 KB |
| 最終ジャッジ日時 | 2025-03-20 20:54:54 |
| 合計ジャッジ時間 | 41,531 ms |
|
ジャッジサーバーID (参考情報) |
judge1 / judge4 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 3 |
| other | AC * 46 |
ソースコード
import sys
from collections import defaultdict
MOD = 998244353
def main():
sys.setrecursionlimit(1 << 25)
N = int(sys.stdin.readline())
A = list(map(int, sys.stdin.readline().split()))
edges = [[] for _ in range(N)]
for _ in range(N-1):
u, v = map(int, sys.stdin.readline().split())
u -= 1
v -= 1
edges[u].append(v)
edges[v].append(u)
total = 0
def dfs(u, parent, s_parent):
nonlocal total
res = 0
sum_s = 0
prev = 0
for v in edges[u]:
if v == parent:
continue
s_child = dfs(v, u, s_parent)
# S += s_parent * s_child * A[u]
contrib = (s_parent * s_child) % MOD
contrib = (contrib * A[u]) % MOD
total = (total + contrib) % MOD
# S += prev * s_child * A[u]
contrib = (prev * s_child) % MOD
contrib = (contrib * A[u]) % MOD
total = (total + contrib) % MOD
prev = (prev + s_child) % MOD
sum_s = (sum_s + s_child) % MOD
# After processing all children, add the contribution from current node's children to itself
contrib = (sum_s * A[u]) % MOD
total = (total + contrib) % MOD
# Compute the return value
ret = ( (s_parent + sum_s) * A[u] + A[u] ) % MOD
return ret
dfs(0, -1, 0)
print(total % MOD)
if __name__ == "__main__":
main()
lam6er