結果
問題 | No.2949 Product on Tree |
ユーザー |
![]() |
提出日時 | 2025-03-31 18:00:22 |
言語 | PyPy3 (7.3.15) |
結果 | AC |
実行時間 | 701 ms / 2,000 ms |
コード長 | 2,302 bytes |
コンパイル時間 | 347 ms |
コンパイル使用メモリ | 82,124 KB |
実行使用メモリ | 187,984 KB |
最終ジャッジ日時 | 2025-03-31 18:01:57 |
合計ジャッジ時間 | 32,853 ms |
ジャッジサーバーID (参考情報) | judge2 / judge3 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | AC * 3 |
other | AC * 46 |
ソースコード
MOD = 998244353


def main():
    """Read a vertex-weighted tree from stdin and print a path-product sum.

    Input (whitespace-separated tokens):
        n
        A_1 ... A_n
        u_1 v_1
        ...
        u_{n-1} v_{n-1}

    Prints, modulo MOD, the sum over every path containing at least two
    vertices (each unordered vertex pair counted once) of the product of the
    A values of the vertices on that path.  Single-vertex paths are not
    counted, so a one-vertex tree prints 0.
    """
    import sys

    # Read all tokens at once; avoid shadowing the builtin ``input``.
    data = sys.stdin.read().split()
    n = int(data[0])
    # values[i] is the weight of vertex i + 1 (vertices are 1-based).
    values = [int(x) for x in data[1:n + 1]]

    adj = [[] for _ in range(n + 1)]
    pos = n + 1
    for _ in range(n - 1):
        u = int(data[pos])
        v = int(data[pos + 1])
        pos += 2
        adj[u].append(v)
        adj[v].append(u)

    # Iterative DFS from vertex 1 records each vertex's parent and a
    # preorder.  Walking that preorder backwards visits every child before
    # its parent, giving a post-order without recursion-depth limits.
    # parent[1] stays 0, which is never a real vertex id.
    parent = [0] * (n + 1)
    order = []
    stack = [1]
    while stack:
        u = stack.pop()
        order.append(u)
        for v in adj[u]:
            if v != parent[u]:
                parent[v] = u
                stack.append(v)

    # down[u]: sum of products over paths starting at u and descending into
    # u's subtree, including the single-vertex path [u].
    down = [0] * (n + 1)
    # total[u]: the answer restricted to paths lying inside u's subtree.
    total = [0] * (n + 1)
    for u in reversed(order):
        a = values[u - 1]
        child_total = 0  # sum of total[v] over children v
        child_down = 0   # running sum of down[v] over children seen so far
        cross = 0        # sum over pairs of distinct children of down[v]*down[w]
        for v in adj[u]:
            if v == parent[u]:
                continue
            child_total += total[v]
            cross = (cross + down[v] * child_down) % MOD
            child_down = (child_down + down[v]) % MOD
        # New paths whose topmost vertex is u: u extended by one downward
        # path (a * child_down), or two downward paths from different
        # children joined at u (a * cross).
        total[u] = (child_total + a * (cross + child_down)) % MOD
        down[u] = a * (1 + child_down) % MOD

    print(total[1])


if __name__ == "__main__":
    main()