結果
| 問題 |
No.1240 Or Sum of Xor Pair
|
| コンテスト | |
| ユーザー |
gew1fw
|
| 提出日時 | 2025-06-12 14:07:56 |
| 言語 | PyPy3 (7.3.15) |
| 結果 |
WA
|
| 実行時間 | - |
| コード長 | 2,320 bytes |
| コンパイル時間 | 245 ms |
| コンパイル使用メモリ | 82,844 KB |
| 実行使用メモリ | 214,248 KB |
| 最終ジャッジ日時 | 2025-06-12 14:08:27 |
| 合計ジャッジ時間 | 14,399 ms |
|
ジャッジサーバーID (参考情報) |
judge3 / judge2 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 3 |
| other | AC * 11 WA * 1 TLE * 1 -- * 17 |
ソースコード
import sys
from sys import stdin
class TrieNode:
def __init__(self):
self.left = None # Represents 0
self.right = None # Represents 1
self.count = 0
def query_trie(node, a, X, bit, state):
if node is None:
return 0
if bit < 0:
return 1 if state == 'lt' else 0
bit_a = (a >> bit) & 1
bit_x = (X >> bit) & 1
if state == 'lt':
return node.count
elif state == 'eq':
total = 0
for b_bit in [0, 1]:
xor_bit = bit_a ^ b_bit
if xor_bit < bit_x:
next_state = 'lt'
elif xor_bit == bit_x:
next_state = 'eq'
else:
continue # No need to proceed for this b_bit
if b_bit == 0:
if node.left:
total += query_trie(node.left, a, X, bit - 1, next_state)
else:
if node.right:
total += query_trie(node.right, a, X, bit - 1, next_state)
return total
else: # state == 'gt'
return 0
def insert_trie(node, a):
current = node
for bit in range(17, -1, -1):
current_bit = (a >> bit) & 1
if current_bit == 0:
if not current.left:
current.left = TrieNode()
current = current.left
else:
if not current.right:
current.right = TrieNode()
current = current.right
current.count += 1
def main():
input = sys.stdin.read().split()
idx = 0
N = int(input[idx])
idx += 1
X = int(input[idx])
idx += 1
A = list(map(int, input[idx:idx+N]))
idx += N
total = 0
root = TrieNode()
for a in A:
cnt = query_trie(root, a, X, 17, 'eq')
total += cnt
insert_trie(root, a)
sum_result = 0
for k in range(18):
B = [a for a in A if (a >> k) & 1 == 0]
if len(B) < 2:
continue
sub_root = TrieNode()
count_both_unset = 0
for a in B:
cnt = query_trie(sub_root, a, X, 17, 'eq')
count_both_unset += cnt
insert_trie(sub_root, a)
contribution = (total - count_both_unset) * (1 << k)
sum_result += contribution
print(sum_result)
if __name__ == "__main__":
main()
gew1fw