結果
| 問題 |
No.2377 SUM AND XOR on Tree
|
| コンテスト | |
| ユーザー |
|
| 提出日時 | 2024-09-05 19:58:50 |
| 言語 | PyPy3 (7.3.15) |
| 結果 |
TLE
|
| 実行時間 | - |
| コード長 | 2,963 bytes |
| コンパイル時間 | 388 ms |
| コンパイル使用メモリ | 82,296 KB |
| 実行使用メモリ | 278,864 KB |
| 最終ジャッジ日時 | 2024-09-05 19:58:59 |
| 合計ジャッジ時間 | 8,387 ms |
|
ジャッジサーバーID (参考情報) |
judge4 / judge1 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 2 |
| other | AC * 6 TLE * 1 -- * 26 |
ソースコード
## https://yukicoder.me/problems/no/2377
from collections import deque
MOD = 998244353
def solve(N, next_nodes, bit_a_list):
# (現役で繋がろうとしている成分の現時点の1xor, すでに完結している連結成分たちの1xor のAnd)
# 0: (0, 0), 1: (0, 1), 2: (1, 0), 3: (1, 1)
dp = [[0, 0, 0, 0] for _ in range(N)]
parents = [-2] * N
parents[0] = -1
stack = deque()
stack.append((0, 0))
while len(stack) > 0:
v, index = stack.pop()
if index == 0:
dp[v][bit_a_list[v] * 2 + 1] = 1
else:
w = next_nodes[v][index - 1]
new_dp = [0] * 4
for state in range(4):
state_num = dp[w][state]
c_num = state // 2
f_num = state % 2
for base_state in range(4):
base_state_num = dp[v][base_state]
base_c_num = base_state // 2
base_f_num = base_state % 2
# つなげる
new_c_num = base_c_num ^ c_num
new_f_num = f_num & base_f_num
new_state = new_c_num * 2 + new_f_num
new_dp[new_state] += (state_num * base_state_num) % MOD
new_dp[new_state] %= MOD
# 独立させる
new_c_num = base_c_num
new_f_num = f_num & base_f_num & c_num
new_state = new_c_num * 2 + new_f_num
new_dp[new_state] += (state_num * base_state_num) % MOD
new_dp[new_state] %= MOD
dp[v] = new_dp
while index < len(next_nodes[v]):
w = next_nodes[v][index]
if w == parents[v]:
index += 1
continue
parents[w] = v
stack.append((v, index + 1))
stack.append((w, 0))
break
# 最後の仕上げ
answer = 0
for state in range(4):
c_num = state // 2
f_num = state % 2
ans = c_num & f_num
if ans == 1:
answer += dp[0][state]
answer %= MOD
return answer
def main():
N = int(input())
next_nodes = [[] for _ in range(N)]
for _ in range(N - 1):
u, v = map(int, input().split())
next_nodes[u - 1].append(v - 1)
next_nodes[v - 1].append(u - 1)
A = list(map(int, input().split()))
# Aの各ビットごとに1の数を計算
max_a = max(A)
k = 0
while (1 << k) < max_a:
k += 1
max_k = k
answer = 0
for k in range(max_k + 1):
bit_a_list = [0] * N
for i in range(N):
a = A[i]
bit_a_list[i] = 1 if a & (1 << k) > 0 else 0
count = solve(N, next_nodes, bit_a_list)
answer += (count * (1 << k)) % MOD
answer %= MOD
print(answer)
if __name__ == '__main__':
main()