結果

問題 No.1240 Or Sum of Xor Pair
ユーザー gew1fw
提出日時 2025-06-12 19:04:10
言語 PyPy3
(7.3.15)
結果
TLE  
実行時間 -
コード長 2,320 bytes
コンパイル時間 318 ms
コンパイル使用メモリ 82,488 KB
実行使用メモリ 214,044 KB
最終ジャッジ日時 2025-06-12 19:04:27
合計ジャッジ時間 15,969 ms
ジャッジサーバーID
(参考情報)
judge1 / judge4
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 3
other AC * 10 WA * 1 TLE * 2 -- * 17
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
from sys import stdin

class TrieNode:
    def __init__(self):
        self.left = None  # Represents 0
        self.right = None  # Represents 1
        self.count = 0

def query_trie(node, a, X, bit, state):
    if node is None:
        return 0
    if bit < 0:
        return 1 if state == 'lt' else 0
    bit_a = (a >> bit) & 1
    bit_x = (X >> bit) & 1
    if state == 'lt':
        return node.count
    elif state == 'eq':
        total = 0
        for b_bit in [0, 1]:
            xor_bit = bit_a ^ b_bit
            if xor_bit < bit_x:
                next_state = 'lt'
            elif xor_bit == bit_x:
                next_state = 'eq'
            else:
                continue  # No need to proceed for this b_bit
            if b_bit == 0:
                if node.left:
                    total += query_trie(node.left, a, X, bit - 1, next_state)
            else:
                if node.right:
                    total += query_trie(node.right, a, X, bit - 1, next_state)
        return total
    else:  # state == 'gt'
        return 0

def insert_trie(node, a):
    current = node
    for bit in range(17, -1, -1):
        current_bit = (a >> bit) & 1
        if current_bit == 0:
            if not current.left:
                current.left = TrieNode()
            current = current.left
        else:
            if not current.right:
                current.right = TrieNode()
            current = current.right
        current.count += 1

def main():
    input = sys.stdin.read().split()
    idx = 0
    N = int(input[idx])
    idx += 1
    X = int(input[idx])
    idx += 1
    A = list(map(int, input[idx:idx+N]))
    idx += N

    total = 0
    root = TrieNode()
    for a in A:
        cnt = query_trie(root, a, X, 17, 'eq')
        total += cnt
        insert_trie(root, a)

    sum_result = 0
    for k in range(18):
        B = [a for a in A if (a >> k) & 1 == 0]
        if len(B) < 2:
            continue

        sub_root = TrieNode()
        count_both_unset = 0
        for a in B:
            cnt = query_trie(sub_root, a, X, 17, 'eq')
            count_both_unset += cnt
            insert_trie(sub_root, a)

        contribution = (total - count_both_unset) * (1 << k)
        sum_result += contribution

    print(sum_result)

if __name__ == "__main__":
    main()
0