結果

問題 No.1240 Or Sum of Xor Pair
ユーザー lam6er
提出日時 2025-03-20 20:30:38
言語 PyPy3
(7.3.15)
結果
TLE  
実行時間 -
コード長 1,992 bytes
コンパイル時間 140 ms
コンパイル使用メモリ 82,336 KB
実行使用メモリ 304,136 KB
最終ジャッジ日時 2025-03-20 20:31:36
合計ジャッジ時間 9,201 ms
ジャッジサーバーID
(参考情報)
judge4 / judge1
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 3
other AC * 12 TLE * 1 -- * 17
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys

class TrieNode:
    __slots__ = ['count', 'children']
    def __init__(self):
        self.count = 0
        self.children = [None, None]

def compute_S(arr, X):
    max_bit = 17  # 0-based index for 18 bits
    root = TrieNode()
    total = 0
    for num in arr:
        current_node = root
        current_count = 0
        node = root
        for d in range(max_bit, -1, -1):
            if node is None:
                break
            bit_num = (num >> d) & 1
            bit_x = (X >> d) & 1
            if bit_x == 1:
                same_bit = bit_num
                if node.children[same_bit] is not None:
                    current_count += node.children[same_bit].count
                opposite_bit = 1 - bit_num
                node = node.children[opposite_bit]
            else:
                same_bit = bit_num
                node = node.children[same_bit]
        total += current_count
        # Insert current number into Trie
        insert_node = root
        for d in range(max_bit, -1, -1):
            bit = (num >> d) & 1
            if insert_node.children[bit] is None:
                insert_node.children[bit] = TrieNode()
            insert_node = insert_node.children[bit]
            insert_node.count += 1
    return total

def main():
    input = sys.stdin.read().split()
    ptr = 0
    N, X = int(input[ptr]), int(input[ptr+1])
    ptr += 2
    A = list(map(int, input[ptr:ptr+N]))
    if X == 0:
        print(0)
        return
    # Preprocess each bit's list
    bits = [[] for _ in range(18)]
    for num in A:
        for k in range(18):
            if (num & (1 << k)) == 0:
                bits[k].append(num)
    S = compute_S(A, X)
    if S == 0:
        print(0)
        return
    ans = 0
    for k in range(18):
        Bk = bits[k]
        if len(Bk) < 2:
            Zk = 0
        else:
            Zk = compute_S(Bk, X)
        Sk = S - Zk
        ans += Sk * (1 << k)
    print(ans)

if __name__ == "__main__":
    main()
0