結果

問題 No.1240 Or Sum of Xor Pair
ユーザー lam6er
提出日時 2025-04-15 23:30:47
言語 PyPy3
(7.3.15)
結果
TLE  
実行時間 -
コード長 2,119 bytes
コンパイル時間 493 ms
コンパイル使用メモリ 81,696 KB
実行使用メモリ 273,860 KB
最終ジャッジ日時 2025-04-15 23:32:24
合計ジャッジ時間 11,078 ms
ジャッジサーバーID
(参考情報)
judge3 / judge5
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 3
other AC * 12 TLE * 1 -- * 17
権限があれば一括ダウンロードができます

ソースコード

diff #

class FastTrie:
    def __init__(self):
        self.nodes = [[0, -1, -1]]  # Each node is [count, left_child, right_child]

    def insert(self, number):
        idx = 0
        for bit in reversed(range(18)):
            b = (number >> bit) & 1
            if self.nodes[idx][b + 1] == -1:
                self.nodes.append([0, -1, -1])
                self.nodes[idx][b + 1] = len(self.nodes) - 1
            idx = self.nodes[idx][b + 1]
            self.nodes[idx][0] += 1

    def query(self, a_j, X):
        stack = [(0, 17, False)]
        result = 0
        while stack:
            node_idx, bit, is_less = stack.pop()
            if bit < 0:
                continue
            if is_less:
                result += self.nodes[node_idx][0]
                continue
            j_bit = (a_j >> bit) & 1
            x_bit = (X >> bit) & 1
            for i_bit in [0, 1]:
                child_idx = self.nodes[node_idx][i_bit + 1]
                if child_idx == -1:
                    continue
                xor_bit = i_bit ^ j_bit
                if xor_bit < x_bit:
                    result += self.nodes[child_idx][0]
                elif xor_bit == x_bit:
                    stack.append((child_idx, bit - 1, False))
        return result

def main():
    import sys
    input = sys.stdin.read().split()
    ptr = 0
    N, X = int(input[ptr]), int(input[ptr + 1])
    ptr += 2
    A = list(map(int, input[ptr:ptr + N]))
    
    bits = 18
    S = [[] for _ in range(bits)]
    for a in A:
        for k in range(bits):
            if (a & (1 << k)) == 0:
                S[k].append(a)
    
    trie_T = FastTrie()
    T = 0
    for a in A:
        T += trie_T.query(a, X)
        trie_T.insert(a)
    
    ans = 0
    for k in range(bits):
        S_k = S[k]
        if len(S_k) < 2:
            C_k = 0
        else:
            trie_C = FastTrie()
            C_k = 0
            for a in S_k:
                C_k += trie_C.query(a, X)
                trie_C.insert(a)
        contribution = (T - C_k) * (1 << k)
        ans += contribution
    print(ans)

if __name__ == '__main__':
    main()
0