結果

問題 No.1240 Or Sum of Xor Pair
ユーザー lam6er
提出日時 2025-03-31 17:31:00
言語 PyPy3
(7.3.15)
結果
TLE  
実行時間 -
コード長 1,901 bytes
コンパイル時間 299 ms
コンパイル使用メモリ 82,036 KB
実行使用メモリ 279,984 KB
最終ジャッジ日時 2025-03-31 17:32:01
合計ジャッジ時間 11,397 ms
ジャッジサーバーID
(参考情報)
judge3 / judge1
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 3
other AC * 12 TLE * 1 -- * 17
権限があれば一括ダウンロードができます

ソースコード

diff #

class TrieNode:
    __slots__ = ['children', 'count']
    def __init__(self):
        self.children = [None, None]
        self.count = 0

def insert(root, num, max_bits):
    current = root
    for k in reversed(range(max_bits)):
        bit = (num >> k) & 1
        if not current.children[bit]:
            current.children[bit] = TrieNode()
        current = current.children[bit]
        current.count += 1

def query(root, num, x, max_bits):
    current = root
    count = 0
    for k in reversed(range(max_bits)):
        if not current:
            break
        a_bit = (num >> k) & 1
        x_bit = (x >> k) & 1
        if x_bit:
            if current.children[a_bit]:
                count += current.children[a_bit].count
            current = current.children[1 - a_bit]
        else:
            current = current.children[a_bit]
    return count

def main():
    import sys
    input = sys.stdin.read().split()
    idx = 0
    N, X = int(input[idx]), int(input[idx+1])
    idx += 2
    A = list(map(int, input[idx:idx+N]))
    idx += N
    
    max_bits = 18
    
    # Precompute S_b for each bit
    pre_s = [[] for _ in range(max_bits)]
    for num in A:
        for b in range(max_bits):
            if (num & (1 << b)) == 0:
                pre_s[b].append(num)
    
    # Compute total count C
    c_root = TrieNode()
    total_pairs = 0
    for num in A:
        total_pairs += query(c_root, num, X, max_bits)
        insert(c_root, num, max_bits)
    
    result = 0
    # Compute each D_b and accumulate the result
    for b in range(max_bits):
        s_b = pre_s[b]
        d_b = 0
        d_root = TrieNode()
        for num in s_b:
            d_b += query(d_root, num, X, max_bits)
            insert(d_root, num, max_bits)
        contribution = (total_pairs - d_b) * (1 << b)
        result += contribution
    
    print(result)

if __name__ == "__main__":
    main()
0