結果
問題 |
No.1240 Or Sum of Xor Pair
|
ユーザー |
![]() |
提出日時 | 2025-04-16 16:09:08 |
言語 | PyPy3 (7.3.15) |
結果 |
TLE
|
実行時間 | - |
コード長 | 2,119 bytes |
コンパイル時間 | 389 ms |
コンパイル使用メモリ | 82,132 KB |
実行使用メモリ | 266,864 KB |
最終ジャッジ日時 | 2025-04-16 16:16:08 |
合計ジャッジ時間 | 12,531 ms |
ジャッジサーバーID (参考情報) |
judge2 / judge3 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | AC * 3 |
other | AC * 12 TLE * 1 -- * 17 |
ソースコード
class FastTrie: def __init__(self): self.nodes = [[0, -1, -1]] # Each node is [count, left_child, right_child] def insert(self, number): idx = 0 for bit in reversed(range(18)): b = (number >> bit) & 1 if self.nodes[idx][b + 1] == -1: self.nodes.append([0, -1, -1]) self.nodes[idx][b + 1] = len(self.nodes) - 1 idx = self.nodes[idx][b + 1] self.nodes[idx][0] += 1 def query(self, a_j, X): stack = [(0, 17, False)] result = 0 while stack: node_idx, bit, is_less = stack.pop() if bit < 0: continue if is_less: result += self.nodes[node_idx][0] continue j_bit = (a_j >> bit) & 1 x_bit = (X >> bit) & 1 for i_bit in [0, 1]: child_idx = self.nodes[node_idx][i_bit + 1] if child_idx == -1: continue xor_bit = i_bit ^ j_bit if xor_bit < x_bit: result += self.nodes[child_idx][0] elif xor_bit == x_bit: stack.append((child_idx, bit - 1, False)) return result def main(): import sys input = sys.stdin.read().split() ptr = 0 N, X = int(input[ptr]), int(input[ptr + 1]) ptr += 2 A = list(map(int, input[ptr:ptr + N])) bits = 18 S = [[] for _ in range(bits)] for a in A: for k in range(bits): if (a & (1 << k)) == 0: S[k].append(a) trie_T = FastTrie() T = 0 for a in A: T += trie_T.query(a, X) trie_T.insert(a) ans = 0 for k in range(bits): S_k = S[k] if len(S_k) < 2: C_k = 0 else: trie_C = FastTrie() C_k = 0 for a in S_k: C_k += trie_C.query(a, X) trie_C.insert(a) contribution = (T - C_k) * (1 << k) ans += contribution print(ans) if __name__ == '__main__': main()