結果
| 問題 | 
                            No.1240 Or Sum of Xor Pair
                             | 
                    
| コンテスト | |
| ユーザー | 
                             lam6er
                         | 
                    
| 提出日時 | 2025-04-15 23:31:47 | 
| 言語 | PyPy3  (7.3.15)  | 
                    
| 結果 | 
                             
                                TLE
                                 
                             
                            
                         | 
                    
| 実行時間 | - | 
| コード長 | 1,854 bytes | 
| コンパイル時間 | 199 ms | 
| コンパイル使用メモリ | 81,648 KB | 
| 実行使用メモリ | 286,280 KB | 
| 最終ジャッジ日時 | 2025-04-15 23:33:13 | 
| 合計ジャッジ時間 | 11,575 ms | 
| 
                            ジャッジサーバーID (参考情報)  | 
                        judge2 / judge4 | 
(要ログイン)
| ファイルパターン | 結果 | 
|---|---|
| sample | AC * 3 | 
| other | AC * 12 TLE * 1 -- * 17 | 
ソースコード
class TrieNode:
    __slots__ = ['children', 'cnt']
    def __init__(self):
        self.children = [None, None]
        self.cnt = 0
def compute_pairs(arr, x):
    root = TrieNode()
    total = 0
    for a in arr:
        current_count = 0
        stack = [(root, 17, True)]
        while stack:
            node, bit, is_tight = stack.pop()
            if bit < 0:
                continue
            a_bit = (a >> bit) & 1
            x_bit = (x >> bit) & 1
            for child_bit in [0, 1]:
                child = node.children[child_bit]
                if not child:
                    continue
                xor_bit = a_bit ^ child_bit
                if is_tight:
                    if xor_bit < x_bit:
                        current_count += child.cnt
                    elif xor_bit == x_bit:
                        stack.append((child, bit - 1, True))
                else:
                    current_count += child.cnt
        total += current_count
        current = root
        for i in reversed(range(18)):
            b = (a >> i) & 1
            if not current.children[b]:
                current.children[b] = TrieNode()
            current = current.children[b]
            current.cnt += 1
    return total
def main():
    import sys
    input = sys.stdin.read().split()
    ptr = 0
    N = int(input[ptr])
    ptr += 1
    X = int(input[ptr])
    ptr += 1
    A = list(map(int, input[ptr:ptr + N]))
    ptr += N
    total_valid = compute_pairs(A, X)
    bits = []
    for k in range(18):
        mask = 1 << k
        B = [a for a in A if (a & mask) == 0]
        bits.append(B)
    ans = 0
    for k in range(18):
        B_k = bits[k]
        cnt_k = compute_pairs(B_k, X)
        contribution = (total_valid - cnt_k) * (1 << k)
        ans += contribution
    print(ans)
if __name__ == "__main__":
    main()
            
            
            
        
            
lam6er