Result

| Problem | No.2377 SUM AND XOR on Tree |
|---|---|
| Contest | |
| User | lam6er |
| Submitted at | 2025-03-31 17:55:43 |
| Language | PyPy3 (7.3.15) |
| Result | AC |
| Execution time | 1,083 ms / 4,000 ms |
| Code length | 2,447 bytes |
| Compile time | 287 ms |
| Compile memory usage | 82,660 KB |
| Execution memory usage | 208,488 KB |
| Last judged at | 2025-03-31 17:57:08 |
| Total judge time | 20,954 ms |
| Judge server ID (for reference) | judge3 / judge5 |

| File pattern | Result |
|---|---|
| sample | AC * 2 |
| other | AC * 33 |
Source code
```python
import sys

MOD = 998244353


def main():
    sys.setrecursionlimit(1 << 25)
    N = int(sys.stdin.readline())
    edges = [[] for _ in range(N + 1)]
    for _ in range(N - 1):
        u, v = map(int, sys.stdin.readline().split())
        edges[u].append(v)
        edges[v].append(u)
    A = list(map(int, sys.stdin.readline().split()))
    A = [0] + A  # Convert to 1-based indexing

    # Root the tree at vertex 1 and record parent/children with an iterative DFS
    parent = [0] * (N + 1)
    children = [[] for _ in range(N + 1)]
    root = 1
    stack = [(root, 0)]
    while stack:
        u, p = stack.pop()
        parent[u] = p
        for v in edges[u]:
            if v != p:
                children[u].append(v)
                stack.append((v, u))

    ans = 0
    # Each bit is handled independently (values are assumed to fit in 30 bits)
    for b in range(30):
        # s[u] is bit b of A[u]
        s = [0] * (N + 1)
        for u in range(1, N + 1):
            s[u] = (A[u] >> b) & 1
        # dp0[u] / dp1[u]: number of ways to cut edges inside u's subtree such that
        # every component already closed off has bit b set, and the component
        # still containing u has XOR bit 0 / 1
        dp0 = [0] * (N + 1)
        dp1 = [0] * (N + 1)
        # Iterative post-order traversal
        stack = [(root, False)]
        while stack:
            u, visited = stack.pop()
            if not visited:
                stack.append((u, True))
                # Push children in reverse order so they are visited left to right
                for v in reversed(children[u]):
                    stack.append((v, False))
            else:
                # Start from the component {u}, whose parity is s[u]
                if s[u] == 0:
                    curr0, curr1 = 1, 0
                else:
                    curr0, curr1 = 0, 1
                # Merge each child's DP
                for v in children[u]:
                    v0 = dp0[v]
                    v1 = dp1[v]
                    # Cut the edge (u, v): v's component closes, so its bit must be 1
                    new0_cut = curr0 * v1 % MOD
                    new1_cut = curr1 * v1 % MOD
                    # Keep the edge: the two parities combine by XOR
                    nc0 = (curr0 * v0 + curr1 * v1) % MOD
                    nc1 = (curr0 * v1 + curr1 * v0) % MOD
                    # Update current DP
                    curr0 = (new0_cut + nc0) % MOD
                    curr1 = (new1_cut + nc1) % MOD
                dp0[u] = curr0
                dp1[u] = curr1
        # At the root the open component closes as well, so it must also have bit b set
        ans = (ans + (dp1[root] % MOD) * ((1 << b) % MOD)) % MOD
    print(ans % MOD)


if __name__ == '__main__':
    main()
```
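A brute-force checker can verify the DP on small trees. The sketch below enumerates every subset of edges to keep or cut. For each subset it takes the bitwise AND of the component XORs and sums the results. This quantity is an assumption read off the DP's recurrences above: a cut is counted only when the closed-off component has bit b set. It is not quoted from the problem statement. `brute_force` is a hypothetical helper. Vertices are 1-based, and `a[i - 1]` is the value on vertex i. It takes O(2^(N-1) · N) time, so it only suits tiny inputs.

```python
from itertools import product

MOD = 998244353


def brute_force(n, edge_list, a):
    """Hypothetical checker: sum over all cut subsets of AND(component XORs)."""

    def find(parent, x):
        # Union-find lookup with path halving
        while parent[x] != x:
            parent[x] = parent[parent[x]]
            x = parent[x]
        return x

    total = 0
    for keep in product((False, True), repeat=len(edge_list)):
        parent = list(range(n + 1))
        # Join the endpoints of every edge that is kept (not cut)
        for kept, (u, v) in zip(keep, edge_list):
            if kept:
                parent[find(parent, u)] = find(parent, v)
        # XOR of the values in each component, keyed by its representative
        comp_xor = {}
        for i in range(1, n + 1):
            r = find(parent, i)
            comp_xor[r] = comp_xor.get(r, 0) ^ a[i - 1]
        # AND across components; -1 (all bits set) is the identity for AND
        value = -1
        for x in comp_xor.values():
            value &= x
        total += value
    return total % MOD
```

For example, take the path 1-2-3 with values 3, 1, 2. Cutting only (1, 2) gives components with XOR 3 and 3 (AND 3). Cutting only (2, 3) gives XOR 2 and 2 (AND 2). Keeping both edges or cutting both gives AND 0. So `brute_force(3, [(1, 2), (2, 3)], [3, 1, 2])` returns 5.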
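The same per-bit DP can also be written as a pure function, which makes it easy to feed small random trees to both it and an exhaustive checker. In this sketch `sum_and_xor` is a hypothetical name, and it differs from the submission in three ways:

- It takes an edge list instead of reading stdin.
- It merges each child into its parent in reverse DFS order rather than with an explicit post-order stack.
- It loops over `max(a).bit_length()` bits instead of a fixed 30, which assumes non-negative values.

```python
MOD = 998244353


def sum_and_xor(n, edge_list, a):
    """Hypothetical pure-function form of the per-bit DP; a[i - 1] is vertex i's value."""
    adj = [[] for _ in range(n + 1)]
    for u, v in edge_list:
        adj[u].append(v)
        adj[v].append(u)
    # Iterative DFS from vertex 1: every vertex appears in `order` before its descendants
    parent = [0] * (n + 1)
    seen = [False] * (n + 1)
    seen[1] = True
    order, stack = [], [1]
    while stack:
        u = stack.pop()
        order.append(u)
        for v in adj[u]:
            if not seen[v]:
                seen[v] = True
                parent[v] = u
                stack.append(v)
    ans = 0
    for b in range(max(a).bit_length()):
        # Start every vertex as the single-vertex component {u} with parity bit b
        dp0 = [0] * (n + 1)
        dp1 = [0] * (n + 1)
        for u in range(1, n + 1):
            if (a[u - 1] >> b) & 1:
                dp1[u] = 1
            else:
                dp0[u] = 1
        # Reverse DFS order finishes each subtree before it is merged into its parent
        for v in reversed(order):
            u = parent[v]
            if u == 0:
                continue  # vertex 1 is the root
            c0, c1, v0, v1 = dp0[u], dp1[u], dp0[v], dp1[v]
            # c0 * v1 and c1 * v1 cut (u, v) (v's component closes with bit 1);
            # the remaining terms keep the edge and XOR the two parities
            dp0[u] = (c0 * v1 + c0 * v0 + c1 * v1) % MOD
            dp1[u] = (c1 * v1 + c0 * v1 + c1 * v0) % MOD
        ans = (ans + dp1[1] * pow(2, b, MOD)) % MOD
    return ans
```

Merging children in any order gives the same result because the cut and keep updates commute. That is why a reverse DFS order is enough here.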