結果
| 問題 | No.3444 Interval Xor MST |
| コンテスト | |
| ユーザー |
👑 potato167
|
| 提出日時 | 2025-12-28 05:19:49 |
| 言語 | PyPy3 (7.3.17) |
| 結果 |
AC
|
| 実行時間 | 1,486 ms / 2,000 ms |
| コード長 | 2,914 bytes |
| 記録 | |
| コンパイル時間 | 4,337 ms |
| コンパイル使用メモリ | 82,668 KB |
| 実行使用メモリ | 107,648 KB |
| 最終ジャッジ日時 | 2026-02-06 20:51:15 |
| 合計ジャッジ時間 | 11,031 ms |
|
ジャッジサーバーID (参考情報) |
judge4 / judge3 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 1 |
| other | AC * 7 |
ソースコード
import sys
def minxor_prefix(a: int, b: int, bits: int) -> int:
"""
0 <= y < b の範囲で (a xor y) を最小化する値を返す。
0 <= a < 2^bits を仮定。bits は扱う下位ビット数。
"""
# b >= 1
B = b - 1
INF = 10**30
dp_free = INF # 既に y <= B が strict に成り立ち、以降は自由
dp_tight = 0 # ここまで y の上位が B と一致(tight)
for i in range(bits - 1, -1, -1):
w = 1 << i
abit = (a >> i) & 1
bbit = (B >> i) & 1
ndp_free = INF
ndp_tight = INF
# free: ybit は 0/1 どちらでもよい
if dp_free != INF:
ndp_free = min(ndp_free, dp_free + (abit * w)) # ybit=0
ndp_free = min(ndp_free, dp_free + ((abit ^ 1) * w)) # ybit=1
# tight: ybit <= bbit
if bbit == 0:
# ybit=0 のみ、tight 維持
ndp_tight = dp_tight + (abit * w)
else:
# ybit=0 で free へ
ndp_free = min(ndp_free, dp_tight + (abit * w))
# ybit=1 で tight 維持
ndp_tight = min(ndp_tight, dp_tight + ((abit ^ 1) * w))
dp_free, dp_tight = ndp_free, ndp_tight
return dp_free if dp_free < dp_tight else dp_tight
def mst_full_block(bit: int) -> int:
"""
集合 [0, 2^bit) の XOR 完全グラフの MST 重み和は bit * 2^(bit-1)
"""
if bit <= 0:
return 0
return bit * (1 << (bit - 1))
def solve_interval(l: int, r: int, bit: int) -> int:
"""
集合 {l, ..., r-1} 上の XOR 完全グラフ MST 重み和。
前提: 0 <= l < r <= 2^bit
"""
if r - l <= 1 or bit == 0:
return 0
# 全ブロックなら閉形式
if l == 0 and r == (1 << bit):
return mst_full_block(bit)
half = 1 << (bit - 1)
if r <= half:
return solve_interval(l, r, bit - 1)
if l >= half:
return solve_interval(l - half, r - half, bit - 1)
# split
left = solve_interval(l, half, bit - 1)
right_len = r - half
right = solve_interval(0, right_len, bit - 1)
a = l
b = right_len
# 追加で 1 本つなぐコスト: 2^(bit-1) + delta
# delta = min_{x in [a,2^(bit-1)), y in [0,b)} (x xor y)
# a < b なら重なって delta=0
if a < b:
delta = 0
else:
delta = minxor_prefix(a, b, bit - 1)
return left + right + half + delta
def solve_case(N: int, M: int) -> int:
L = M
R = M + N
# M,N <= 2e9 より R < 4e9 < 2^32、bit=32 で被覆
return solve_interval(L, R, 32)
def main() -> None:
data = list(map(int, sys.stdin.buffer.read().split()))
T = data[0]
out = []
idx = 1
for _ in range(T):
N = data[idx]
M = data[idx + 1]
idx += 2
out.append(str(solve_case(N, M)))
sys.stdout.write("\n".join(out))
if __name__ == "__main__":
main()
potato167