結果
問題 |
No.1240 Or Sum of Xor Pair
|
ユーザー |
![]() |
提出日時 | 2025-04-15 23:33:23 |
言語 | PyPy3 (7.3.15) |
結果 |
TLE
|
実行時間 | - |
コード長 | 1,854 bytes |
コンパイル時間 | 172 ms |
コンパイル使用メモリ | 82,388 KB |
実行使用メモリ | 286,104 KB |
最終ジャッジ日時 | 2025-04-15 23:34:36 |
合計ジャッジ時間 | 10,961 ms |
ジャッジサーバーID (参考情報) |
judge5 / judge4 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | AC * 3 |
other | AC * 12 TLE * 1 -- * 17 |
ソースコード
class TrieNode: __slots__ = ['children', 'cnt'] def __init__(self): self.children = [None, None] self.cnt = 0 def compute_pairs(arr, x): root = TrieNode() total = 0 for a in arr: current_count = 0 stack = [(root, 17, True)] while stack: node, bit, is_tight = stack.pop() if bit < 0: continue a_bit = (a >> bit) & 1 x_bit = (x >> bit) & 1 for child_bit in [0, 1]: child = node.children[child_bit] if not child: continue xor_bit = a_bit ^ child_bit if is_tight: if xor_bit < x_bit: current_count += child.cnt elif xor_bit == x_bit: stack.append((child, bit - 1, True)) else: current_count += child.cnt total += current_count current = root for i in reversed(range(18)): b = (a >> i) & 1 if not current.children[b]: current.children[b] = TrieNode() current = current.children[b] current.cnt += 1 return total def main(): import sys input = sys.stdin.read().split() ptr = 0 N = int(input[ptr]) ptr += 1 X = int(input[ptr]) ptr += 1 A = list(map(int, input[ptr:ptr + N])) ptr += N total_valid = compute_pairs(A, X) bits = [] for k in range(18): mask = 1 << k B = [a for a in A if (a & mask) == 0] bits.append(B) ans = 0 for k in range(18): B_k = bits[k] cnt_k = compute_pairs(B_k, X) contribution = (total_valid - cnt_k) * (1 << k) ans += contribution print(ans) if __name__ == "__main__": main()