結果

問題 No.1240 Or Sum of Xor Pair
ユーザー gew1fw
提出日時 2025-06-12 13:12:27
言語 PyPy3
(7.3.15)
結果
TLE  
実行時間 -
コード長 1,888 bytes
コンパイル時間 171 ms
コンパイル使用メモリ 83,028 KB
実行使用メモリ 282,056 KB
最終ジャッジ日時 2025-06-12 13:15:49
合計ジャッジ時間 11,826 ms
ジャッジサーバーID
(参考情報)
judge2 / judge4
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 3
other AC * 12 TLE * 1 -- * 17
権限があれば一括ダウンロードができます

ソースコード

diff #

class TrieNode:
    __slots__ = ['children', 'count']
    def __init__(self):
        self.children = [None, None]
        self.count = 0

def insert_trie(root, num):
    node = root
    for bit in reversed(range(18)):  # Process from MSB to LSB (17 to 0)
        b = (num >> bit) & 1
        if not node.children[b]:
            node.children[b] = TrieNode()
        node = node.children[b]
        node.count += 1

def query_trie(root, a, X):
    res = 0
    node = root
    for bit in reversed(range(18)):
        if not node:
            break
        a_bit = (a >> bit) & 1
        x_bit = (X >> bit) & 1
        desired_trie_bit = a_bit ^ x_bit
        other_trie_bit = 1 - desired_trie_bit
        other_xor_bit = a_bit ^ other_trie_bit

        if other_xor_bit < x_bit:
            if node.children[other_trie_bit]:
                res += node.children[other_trie_bit].count

        if node.children[desired_trie_bit]:
            node = node.children[desired_trie_bit]
        else:
            node = None
    return res

def main():
    import sys
    input = sys.stdin.read().split()
    idx = 0
    N = int(input[idx])
    idx += 1
    X = int(input[idx])
    idx += 1
    A = list(map(int, input[idx:idx+N]))
    idx += N

    # Compute T: total valid pairs
    root = TrieNode()
    T = 0
    for num in A:
        T += query_trie(root, num, X)
        insert_trie(root, num)

    sum_total = 0
    for b in range(18):
        # Filter elements where bit b is not set
        filtered = [num for num in A if (num & (1 << b)) == 0]
        # Compute T_b for this filtered array
        root_b = TrieNode()
        T_b = 0
        for num in filtered:
            T_b += query_trie(root_b, num, X)
            insert_trie(root_b, num)
        contribution = (T - T_b) * (1 << b)
        sum_total += contribution

    print(sum_total)

if __name__ == "__main__":
    main()
0