結果

問題 No.3444 Interval Xor MST
コンテスト
ユーザー 👑 potato167
提出日時 2025-12-28 05:13:18
言語 PyPy3
(7.3.17)
結果
AC  
実行時間 1,637 ms / 2,000 ms
コード長 2,917 bytes
記録
記録タグの例:
初AC ショートコード 純ショートコード 純主流ショートコード 最速実行時間
コンパイル時間 3,463 ms
コンパイル使用メモリ 81,904 KB
実行使用メモリ 107,556 KB
最終ジャッジ日時 2026-02-06 20:51:15
合計ジャッジ時間 12,741 ms
ジャッジサーバーID
(参考情報)
judge2 / judge5
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 1
other AC * 7
権限があれば一括ダウンロードができます

ソースコード

diff #
raw source code

import sys

def minxor_prefix(a: int, b: int, bit: int) -> int:
    """
    b > 0.
    0 <= y < b の範囲で (a xor y) を最小化する値を返す。
    0 <= a < 2^bit を仮定。bit は扱う下位ビット数。
    """
    B = b - 1
    INF = 10**30

    dp_free = INF   # すでに B より小さいことが確定(以降は自由)
    dp_tight = 0    # ここまで B と一致(以降も上限に制約)

    for i in range(bit - 1, -1, -1):
        w = 1 << i
        abit = (a >> i) & 1
        bbit = (B >> i) & 1

        ndp_free = INF
        ndp_tight = INF

        # free からは ybit を 0/1 どちらでも
        if dp_free != INF:
            ndp_free = min(ndp_free, dp_free + (abit * w))           # ybit=0
            ndp_free = min(ndp_free, dp_free + ((abit ^ 1) * w))     # ybit=1

        # tight からは ybit <= bbit
        if bbit == 0:
            # ybit=0 のみ、tight のまま
            ndp_tight = dp_tight + (abit * w)
        else:
            # ybit=0 => free へ
            ndp_free = min(ndp_free, dp_tight + (abit * w))
            # ybit=1 => tight 維持
            ndp_tight = min(ndp_tight, dp_tight + ((abit ^ 1) * w))

        dp_free, dp_tight = ndp_free, ndp_tight

    return dp_free if dp_free < dp_tight else dp_tight


def solve_interval(l: int, r: int, bit: int = 32) -> int:
    """
    集合 {l, l+1, ..., r-1} 上の完全グラフで、辺重みが XOR のときの MST 重み和。
    0 <= l < r <= 2^bit を前提に、上位 bit から分割して計算する。
    """
    if r - l <= 1 or bit == 0:
        return 0

    # 全ブロック [0, 2^bit) のときは閉形式で返す(重要)
    if l == 0 and r == (1 << bit):
        return bit * (1 << (bit - 1))

    half = 1 << (bit - 1)

    # 下半分に完全に収まる
    if r <= half:
        return solve_interval(l, r, bit - 1)

    # 上半分に完全に収まる(half を引いて下へ詰める)
    if l >= half:
        return solve_interval(l - half, r - half, bit - 1)

    # 跨ぐ:左 [l, half), 右 [0, r-half)
    left = solve_interval(l, half, bit - 1)
    right_len = r - half
    right = solve_interval(0, right_len, bit - 1)

    a = l
    b = right_len
    if a < b:
        delta = 0  # 重なりがあるので同じ値を選べて XOR 0
    else:
        # min_{0<=y<b} (a xor y)
        delta = minxor_prefix(a, b, bit - 1)

    return left + right + half + delta


def solve_case(N: int, M: int) -> int:
    L = M
    R = M + N
    # 最大でも < 2^32 なので bit=32 で覆える
    return solve_interval(L, R, 32)


def main():
    data = list(map(int, sys.stdin.buffer.read().split()))
    T = data[0]
    out = []
    idx = 1
    for _ in range(T):
        N = data[idx]; M = data[idx + 1]
        idx += 2
        out.append(str(solve_case(N, M)))
    sys.stdout.write("\n".join(out))

if __name__ == "__main__":
    main()
0