結果
| 問題 | No.3444 Interval Xor MST |
| コンテスト | |
| ユーザー |
👑 potato167
|
| 提出日時 | 2025-12-28 05:13:18 |
| 言語 | PyPy3 (7.3.17) |
| 結果 |
AC
|
| 実行時間 | 1,637 ms / 2,000 ms |
| コード長 | 2,917 bytes |
| 記録 | |
| コンパイル時間 | 3,463 ms |
| コンパイル使用メモリ | 81,904 KB |
| 実行使用メモリ | 107,556 KB |
| 最終ジャッジ日時 | 2026-02-06 20:51:15 |
| 合計ジャッジ時間 | 12,741 ms |
|
ジャッジサーバーID (参考情報) |
judge2 / judge5 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 1 |
| other | AC * 7 |
ソースコード
import sys
def minxor_prefix(a: int, b: int, bit: int) -> int:
"""
b > 0.
0 <= y < b の範囲で (a xor y) を最小化する値を返す。
0 <= a < 2^bit を仮定。bit は扱う下位ビット数。
"""
B = b - 1
INF = 10**30
dp_free = INF # すでに B より小さいことが確定(以降は自由)
dp_tight = 0 # ここまで B と一致(以降も上限に制約)
for i in range(bit - 1, -1, -1):
w = 1 << i
abit = (a >> i) & 1
bbit = (B >> i) & 1
ndp_free = INF
ndp_tight = INF
# free からは ybit を 0/1 どちらでも
if dp_free != INF:
ndp_free = min(ndp_free, dp_free + (abit * w)) # ybit=0
ndp_free = min(ndp_free, dp_free + ((abit ^ 1) * w)) # ybit=1
# tight からは ybit <= bbit
if bbit == 0:
# ybit=0 のみ、tight のまま
ndp_tight = dp_tight + (abit * w)
else:
# ybit=0 => free へ
ndp_free = min(ndp_free, dp_tight + (abit * w))
# ybit=1 => tight 維持
ndp_tight = min(ndp_tight, dp_tight + ((abit ^ 1) * w))
dp_free, dp_tight = ndp_free, ndp_tight
return dp_free if dp_free < dp_tight else dp_tight
def solve_interval(l: int, r: int, bit: int = 32) -> int:
"""
集合 {l, l+1, ..., r-1} 上の完全グラフで、辺重みが XOR のときの MST 重み和。
0 <= l < r <= 2^bit を前提に、上位 bit から分割して計算する。
"""
if r - l <= 1 or bit == 0:
return 0
# 全ブロック [0, 2^bit) のときは閉形式で返す(重要)
if l == 0 and r == (1 << bit):
return bit * (1 << (bit - 1))
half = 1 << (bit - 1)
# 下半分に完全に収まる
if r <= half:
return solve_interval(l, r, bit - 1)
# 上半分に完全に収まる(half を引いて下へ詰める)
if l >= half:
return solve_interval(l - half, r - half, bit - 1)
# 跨ぐ:左 [l, half), 右 [0, r-half)
left = solve_interval(l, half, bit - 1)
right_len = r - half
right = solve_interval(0, right_len, bit - 1)
a = l
b = right_len
if a < b:
delta = 0 # 重なりがあるので同じ値を選べて XOR 0
else:
# min_{0<=y<b} (a xor y)
delta = minxor_prefix(a, b, bit - 1)
return left + right + half + delta
def solve_case(N: int, M: int) -> int:
L = M
R = M + N
# 最大でも < 2^32 なので bit=32 で覆える
return solve_interval(L, R, 32)
def main():
data = list(map(int, sys.stdin.buffer.read().split()))
T = data[0]
out = []
idx = 1
for _ in range(T):
N = data[idx]; M = data[idx + 1]
idx += 2
out.append(str(solve_case(N, M)))
sys.stdout.write("\n".join(out))
if __name__ == "__main__":
main()
potato167