結果

問題 No.2377 SUM AND XOR on Tree
ユーザー lam6er
提出日時 2025-03-26 15:52:18
言語 PyPy3
(7.3.15)
結果
AC  
実行時間 2,826 ms / 4,000 ms
コード長 2,148 bytes
コンパイル時間 284 ms
コンパイル使用メモリ 82,628 KB
実行使用メモリ 279,672 KB
最終ジャッジ日時 2025-03-26 15:53:39
合計ジャッジ時間 44,200 ms
ジャッジサーバーID
(参考情報)
judge1 / judge2
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 2
other AC * 33
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
from collections import deque

MOD = 998244353

def main():
    sys.setrecursionlimit(1 << 25)
    input = sys.stdin.read().split()
    ptr = 0
    N = int(input[ptr])
    ptr += 1
    edges = [[] for _ in range(N)]
    for _ in range(N-1):
        u = int(input[ptr])-1
        v = int(input[ptr+1])-1
        ptr +=2
        edges[u].append(v)
        edges[v].append(u)
    A = list(map(int, input[ptr:ptr+N]))
    ptr += N
    
    # Build parent and children
    root = 0
    parent = [ -1 ] * N
    children = [[] for _ in range(N)]
    q = deque([root])
    parent[root] = -1
    while q:
        u = q.popleft()
        for v in edges[u]:
            if parent[v] == -1 and v != parent[u]:
                parent[v] = u
                children[u].append(v)
                q.append(v)
    
    # Compute post_order
    post_order = []
    stack = [(root, False)]
    while stack:
        node, visited = stack.pop()
        if visited:
            post_order.append(node)
            continue
        stack.append((node, True))
        for child in reversed(children[node]):
            stack.append((child, False))
    
    total = 0
    for b in range(30):
        a_ub = [ (A[i] >> b) & 1 for i in range(N) ]
        S = sum(a_ub)
        if S == 0:
            continue
        
        dp = [ [0]*2 for _ in range(N) ]
        for u in post_order:
            current_parity = a_ub[u]
            dp[u][current_parity] = 1
            for v in children[u]:
                new_dp = [0, 0]
                for p in 0, 1:
                    if dp[u][p] == 0:
                        continue
                    # Option 1: cut edge
                    new_dp[p] = (new_dp[p] + dp[u][p] * dp[v][1]) % MOD
                    # Option 2: not cut
                    for q in 0, 1:
                        new_p = p ^ q
                        new_dp[new_p] = (new_dp[new_p] + dp[u][p] * dp[v][q]) % MOD
                dp[u][0], dp[u][1] = new_dp[0] % MOD, new_dp[1] % MOD
        
        ans_b = dp[root][1]
        total = (total + ans_b * (1 << b)) % MOD
    
    print(total % MOD)

if __name__ == '__main__':
    main()
0