結果

問題 No.3444 Interval Xor MST
コンテスト
ユーザー 👑 potato167
提出日時 2025-12-28 05:19:49
言語 PyPy3
(7.3.17)
結果
AC  
実行時間 1,486 ms / 2,000 ms
コード長 2,914 bytes
記録
記録タグの例:
初AC ショートコード 純ショートコード 純主流ショートコード 最速実行時間
コンパイル時間 4,337 ms
コンパイル使用メモリ 82,668 KB
実行使用メモリ 107,648 KB
最終ジャッジ日時 2026-02-06 20:51:15
合計ジャッジ時間 11,031 ms
ジャッジサーバーID
(参考情報)
judge4 / judge3
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 1
other AC * 7
権限があれば一括ダウンロードができます

ソースコード

diff #
raw source code

import sys

def minxor_prefix(a: int, b: int, bits: int) -> int:
    """
    0 <= y < b の範囲で (a xor y) を最小化する値を返す。
    0 <= a < 2^bits を仮定。bits は扱う下位ビット数。
    """
    # b >= 1
    B = b - 1
    INF = 10**30

    dp_free = INF   # 既に y <= B が strict に成り立ち、以降は自由
    dp_tight = 0    # ここまで y の上位が B と一致(tight)

    for i in range(bits - 1, -1, -1):
        w = 1 << i
        abit = (a >> i) & 1
        bbit = (B >> i) & 1

        ndp_free = INF
        ndp_tight = INF

        # free: ybit は 0/1 どちらでもよい
        if dp_free != INF:
            ndp_free = min(ndp_free, dp_free + (abit * w))           # ybit=0
            ndp_free = min(ndp_free, dp_free + ((abit ^ 1) * w))     # ybit=1

        # tight: ybit <= bbit
        if bbit == 0:
            # ybit=0 のみ、tight 維持
            ndp_tight = dp_tight + (abit * w)
        else:
            # ybit=0 で free へ
            ndp_free = min(ndp_free, dp_tight + (abit * w))
            # ybit=1 で tight 維持
            ndp_tight = min(ndp_tight, dp_tight + ((abit ^ 1) * w))

        dp_free, dp_tight = ndp_free, ndp_tight

    return dp_free if dp_free < dp_tight else dp_tight


def mst_full_block(bit: int) -> int:
    """
    集合 [0, 2^bit) の XOR 完全グラフの MST 重み和は bit * 2^(bit-1)
    """
    if bit <= 0:
        return 0
    return bit * (1 << (bit - 1))


def solve_interval(l: int, r: int, bit: int) -> int:
    """
    集合 {l, ..., r-1} 上の XOR 完全グラフ MST 重み和。
    前提: 0 <= l < r <= 2^bit
    """
    if r - l <= 1 or bit == 0:
        return 0

    # 全ブロックなら閉形式
    if l == 0 and r == (1 << bit):
        return mst_full_block(bit)

    half = 1 << (bit - 1)

    if r <= half:
        return solve_interval(l, r, bit - 1)
    if l >= half:
        return solve_interval(l - half, r - half, bit - 1)

    # split
    left = solve_interval(l, half, bit - 1)
    right_len = r - half
    right = solve_interval(0, right_len, bit - 1)

    a = l
    b = right_len

    # 追加で 1 本つなぐコスト: 2^(bit-1) + delta
    # delta = min_{x in [a,2^(bit-1)), y in [0,b)} (x xor y)
    # a < b なら重なって delta=0
    if a < b:
        delta = 0
    else:
        delta = minxor_prefix(a, b, bit - 1)

    return left + right + half + delta


def solve_case(N: int, M: int) -> int:
    L = M
    R = M + N
    # M,N <= 2e9 より R < 4e9 < 2^32、bit=32 で被覆
    return solve_interval(L, R, 32)


def main() -> None:
    data = list(map(int, sys.stdin.buffer.read().split()))
    T = data[0]
    out = []
    idx = 1
    for _ in range(T):
        N = data[idx]
        M = data[idx + 1]
        idx += 2
        out.append(str(solve_case(N, M)))
    sys.stdout.write("\n".join(out))

if __name__ == "__main__":
    main()
0