結果

問題 No.1240 Or Sum of Xor Pair
ユーザー gew1fw
提出日時 2025-06-12 13:11:30
言語 PyPy3
(7.3.15)
結果
TLE  
実行時間 -
コード長 2,004 bytes
コンパイル時間 202 ms
コンパイル使用メモリ 82,652 KB
実行使用メモリ 276,084 KB
最終ジャッジ日時 2025-06-12 13:14:47
合計ジャッジ時間 12,980 ms
ジャッジサーバーID
(参考情報)
judge3 / judge4
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 3
other AC * 12 TLE * 1 -- * 17
権限があれば一括ダウンロードができます

ソースコード

diff #

class TrieNode:
    __slots__ = ['children', 'count']
    def __init__(self):
        self.children = [None, None]
        self.count = 0

class Trie:
    def __init__(self):
        self.root = TrieNode()

    def insert(self, num):
        node = self.root
        for bit in reversed(range(18)):
            b = (num >> bit) & 1
            if not node.children[b]:
                node.children[b] = TrieNode()
            node = node.children[b]
            node.count += 1

    def query(self, num, x):
        node = self.root
        count = 0
        for bit in reversed(range(18)):
            if not node:
                break
            a_bit = (num >> bit) & 1
            x_bit = (x >> bit) & 1
            for child_bit in [0, 1]:
                if node.children[child_bit]:
                    xor_bit = a_bit ^ child_bit
                    if xor_bit < x_bit:
                        count += node.children[child_bit].count
            desired_child_bit = a_bit ^ x_bit
            if node.children[desired_child_bit]:
                node = node.children[desired_child_bit]
            else:
                node = None
        return count

def main():
    import sys
    input = sys.stdin.read().split()
    ptr = 0
    N = int(input[ptr])
    ptr += 1
    X = int(input[ptr])
    ptr += 1
    A = list(map(int, input[ptr:ptr+N]))
    ptr += N

    # Compute total_valid_pairs
    trie_total = Trie()
    total_valid = 0
    for a in A:
        total_valid += trie_total.query(a, X)
        trie_total.insert(a)

    ans = 0
    for b in range(18):
        # Filter elements where b-th bit is 0
        filtered = [a for a in A if (a & (1 << b)) == 0]
        count_b = 0
        if len(filtered) >= 2:
            trie_b = Trie()
            for a in filtered:
                count_b += trie_b.query(a, X)
                trie_b.insert(a)
        contribution = (total_valid - count_b) * (1 << b)
        ans += contribution
    print(ans)

if __name__ == "__main__":
    main()
0