Result
| Problem | No.2377 SUM AND XOR on Tree |
|---|---|
| Contest | |
| User | lam6er |
| Submitted at | 2025-03-20 20:58:46 |
| Language | PyPy3 (7.3.15) |
| Result | AC |
| Execution time | 1,350 ms / 4,000 ms |
| Code length | 2,499 bytes |
| Compile time | 438 ms |
| Compile memory usage | 82,172 KB |
| Execution memory usage | 202,844 KB |
| Last judged at | 2025-03-20 21:00:09 |
| Total judge time | 23,112 ms |
| Judge server ID (for reference) | judge1 / judge4 |
| File pattern | Result |
|---|---|
| sample | AC * 2 |
| other | AC * 33 |
Source code
MOD = 998244353


def main():
    import sys
    input = sys.stdin.read
    data = input().split()
    idx = 0
    n = int(data[idx])
    idx += 1
    edges = [[] for _ in range(n+1)]
    for _ in range(n-1):
        u = int(data[idx])
        v = int(data[idx+1])
        edges[u].append(v)
        edges[v].append(u)
        idx += 2
    A = list(map(int, data[idx:idx+n]))
    idx += n
    # Build the tree structure with parent pointers and children list
    root = 1
    parent = [0] * (n + 1)
    children = [[] for _ in range(n + 1)]
    stack = [(root, 0)]
    while stack:
        u, p = stack.pop()
        parent[u] = p
        for v in edges[u]:
            if v != p:
                children[u].append(v)
                stack.append((v, u))
    total_ans = 0
    for b in range(30):
        # a[u] is bit b of A_u
        a = [0] * (n + 1)
        for u in range(1, n+1):
            a[u] = (A[u-1] >> b) & 1
        # dp0[u] / dp1[u]: number of ways to cut edges inside u's subtree such that
        # every component already cut off has bit b set in its XOR, and the
        # component still containing u has XOR bit 0 / 1
        dp0 = [0] * (n + 1)
        dp1 = [0] * (n + 1)
        # Iterative post-order traversal
        stack = [(root, False)]
        while stack:
            u, visited = stack.pop()
            if not visited:
                stack.append((u, True))
                for v in reversed(children[u]):
                    stack.append((v, False))
            else:
                # Initialize with node u's own bit
                if a[u] == 0:
                    curr0, curr1 = 1, 0
                else:
                    curr0, curr1 = 0, 1
                for v in children[u]:
                    v0 = dp0[v]
                    v1 = dp1[v]
                    # Cut edge (u, v): v's component is closed off, so its XOR bit must be 1
                    new0 = (curr0 * v1) % MOD
                    new1 = (curr1 * v1) % MOD
                    # Keep edge (u, v): the two parities combine by XOR
                    # (0 from 0^0 or 1^1, 1 from 0^1 or 1^0)
                    con0 = (curr0 * v0 + curr1 * v1) % MOD
                    con1 = (curr0 * v1 + curr1 * v0) % MOD
                    new0 = (new0 + con0) % MOD
                    new1 = (new1 + con1) % MOD
                    curr0, curr1 = new0, new1
                dp0[u], dp1[u] = curr0, curr1
        # Cuts in which every component, including the root's, has bit b set
        ways = dp1[root]
        total_ans = (total_ans + (ways << b)) % MOD
    print(total_ans % MOD)


if __name__ == "__main__":
    main()
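The inner loop of the submission can be written as a recurrence. The symbols $a_u$, $d_v^{(0)}$ and $d_v^{(1)}$ are introduced here for illustration and correspond to `a[u]`, `dp0[v]` and `dp1[v]` in the code. The reading "components already cut off must have bit $b$ set" is inferred from the code, not quoted from the problem statement. For a fixed bit $b$, let $a_u = \lfloor A_u / 2^b \rfloor \bmod 2$. Start from $(c_0, c_1) = (1 - a_u,\ a_u)$, and fold in each child $v$ of $u$. The first term is "cut edge $(u, v)$" and the parenthesized part is "keep edge $(u, v)$":

$$
\begin{aligned}
c_0' &= c_0\,d_v^{(1)} + \bigl(c_0\,d_v^{(0)} + c_1\,d_v^{(1)}\bigr) \pmod{998244353},\\
c_1' &= c_1\,d_v^{(1)} + \bigl(c_0\,d_v^{(1)} + c_1\,d_v^{(0)}\bigr) \pmod{998244353}.
\end{aligned}
$$

After all children are processed, $d_u^{(0)} = c_0$ and $d_u^{(1)} = c_1$. The printed answer is

$$
\sum_{b=0}^{29} 2^b \, d_{\mathrm{root}}^{(1)} \bmod 998244353,
$$

where $d_{\mathrm{root}}^{(1)}$ is recomputed for each bit $b$. The total work is therefore $O(30N)$.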
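The page does not include the problem statement. Judging from the DP, the task appears to be as follows: over all $2^{N-1}$ ways of cutting a subset of the tree's edges, sum the bitwise AND of the XORs of $A$ over the resulting components, modulo 998244353. The sketch below cross-checks the submission's approach under that assumption. `brute_force` enumerates every edge subset directly, and `dp_solution` re-implements the per-bit DP above as a function. All helper names (`find`, `brute_force`, `dp_solution`, `random_tree`) are hypothetical and are not part of the submission or of yukicoder.

```python
import random
from itertools import product

MOD = 998244353


def find(parent, x):
    # Union-find root lookup with path halving.
    while parent[x] != x:
        parent[x] = parent[parent[x]]
        x = parent[x]
    return x


def brute_force(n, edge_list, A):
    """Sum over all 2^(n-1) edge subsets of the AND of component XORs (assumed task)."""
    total = 0
    for keep in product((False, True), repeat=n - 1):
        parent = list(range(n + 1))
        for (u, v), k in zip(edge_list, keep):
            if k:  # edge kept: merge the two endpoints' components
                parent[find(parent, u)] = find(parent, v)
        comp_xor = {}
        for u in range(1, n + 1):
            r = find(parent, u)
            comp_xor[r] = comp_xor.get(r, 0) ^ A[u - 1]
        score = -1  # all bits set, the identity for AND
        for x in comp_xor.values():
            score &= x
        total += score
    return total % MOD


def dp_solution(n, edge_list, A, bits=30):
    """Per-bit tree DP using the same transitions as the submission."""
    adj = [[] for _ in range(n + 1)]
    for u, v in edge_list:
        adj[u].append(v)
        adj[v].append(u)
    # Preorder from vertex 1: every parent appears before its children.
    order = []
    par = [0] * (n + 1)
    seen = [False] * (n + 1)
    seen[1] = True
    stack = [1]
    while stack:
        u = stack.pop()
        order.append(u)
        for v in adj[u]:
            if not seen[v]:
                seen[v] = True
                par[v] = u
                stack.append(v)
    total = 0
    for b in range(bits):
        dp0 = [0] * (n + 1)
        dp1 = [0] * (n + 1)
        for u in reversed(order):  # children are finished before their parent
            c0, c1 = (0, 1) if (A[u - 1] >> b) & 1 else (1, 0)
            for v in adj[u]:
                if v == par[u]:
                    continue
                v0, v1 = dp0[v], dp1[v]
                # cut (child component must have bit 1) + keep (parities XOR together)
                c0, c1 = ((c0 * v1 + c0 * v0 + c1 * v1) % MOD,
                          (c1 * v1 + c0 * v1 + c1 * v0) % MOD)
            dp0[u], dp1[u] = c0, c1
        total = (total + dp1[1] * (1 << b)) % MOD
    return total


def random_tree(n, rng):
    # Attach each vertex i >= 2 to a random earlier vertex.
    return [(rng.randint(1, i - 1), i) for i in range(2, n + 1)]


if __name__ == "__main__":
    rng = random.Random(0)
    for _ in range(200):
        n = rng.randint(1, 8)
        edges = random_tree(n, rng)
        A = [rng.randrange(16) for _ in range(n)]
        assert dp_solution(n, edges, A) == brute_force(n, edges, A)
    print("all random cases agree")
```

The brute force enumerates $2^{N-1}$ subsets, so it is only practical for tiny trees. Keeping $A_i$ below 16 keeps the random cases fast while still exercising several bits.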