結果

問題 No.3398 Accuracy of Integer Division Approximate Function 2
コンテスト
ユーザー 👑 Mizar
提出日時 2025-11-02 15:07:01
言語 PyPy3
(7.3.15)
結果
AC  
実行時間 181 ms / 2,000 ms
コード長 7,719 bytes
記録
コンパイル時間 172 ms
コンパイル使用メモリ 82,512 KB
実行使用メモリ 77,504 KB
最終ジャッジ日時 2025-12-04 23:30:16
合計ジャッジ時間 2,979 ms
ジャッジサーバーID
(参考情報)
judge1 / judge2
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 3
other AC * 20
権限があれば一括ダウンロードができます

ソースコード

diff #
raw source code

# -*- coding: utf-8 -*-
"""
Max Weighted Floor (mwf) を用いて x_min(D, A, B, K) を求める。
"""


def mwf(n: int, m: int, a: int, b: int, c: int, d: int) -> int:
    """
    Max Weighted Floor (mwf) の非再帰実装。
      mwf(n,m,a,b,c,d) = max_{0 <= x < n} a*x + b*floor((c*x + d)/m)

    前提:
      - n > 0, m > 0

    計算量/メモリ:
      - 時間: O(log m)(ユークリッド互除法的再帰による構造縮約)
      - 追加メモリ: O(1)
    """
    assert n > 0 and m > 0
    sum_acc: int = int(0)  # 現在の累積和
    max_acc: int = b * (d // m)  # 現在の累積max. 初期値は x = 0 のときの値
    while True:
        # c, d をそれぞれ 正の整数 m で割った剰余にする正規化
        # Python の divmod は Flooring Division に基づくので、除数 m が正であるため
        # 元の c, d が負でも正規化後の剰余は 0 <= c < m, 0 <= d < m が保証される
        # 負の整数 % 正の整数 = 負の整数 となる言語(C++/Java など)では移植時に注意
        q, c = divmod(c, m)  # q = c // m, c = c % m
        a += b * q  # c の商分を a に足す
        q, d = divmod(d, m)  # q = d // m, d = d % m
        sum_acc += b * q  # d の商分を s に足す
        assert 0 <= c < m and 0 <= d < m
        # 現在の小問題における x = 0 のときの値 s を r に反映
        max_acc = max(max_acc, sum_acc)
        # 0 ≤ x < n における y = floor((c*x+d)/m) の最大値を計算
        y_max = (c * (n - 1) + d) // m
        # y_max == 0 の場合は右端を考慮して終了
        if y_max == 0:
            return max(max_acc, sum_acc + a * (n - 1))
        # y_max >= 1 の場合は再帰的に解く
        # c > 0, n > 1 のときにのみ y_max >= 1 となりうる
        if a >= 0:
            # a >= 0 の場合
            max_acc = max(max_acc, sum_acc + a * (n - 1) + b * y_max)
        else:
            # a < 0 の場合
            sum_acc += a + b
        # 小問題へのパラメータ変換
        n, m, a, b, c, d = y_max, c, b, a, m, (m - d - 1)


def mwf_leq(z: int, n: int, m: int, a: int, b: int, c: int, d: int) -> bool:
    """
    Max Weighted Floor (mwf) の非再帰実装。
      mwf(n,m,a,b,c,d) = max_{0 <= x < n} a*x + b*floor((c*x + d)/m)

    返り値: mwf(n,m,a,b,c,d) <= z なら True、そうでなければ False を返す。

    前提:
      - n > 0, m > 0

    計算量/メモリ:
      - 時間: O(log m)(ユークリッド互除法的再帰による構造縮約)
      - 追加メモリ: O(1)
    """
    assert n > 0 and m > 0
    sum_acc: int = -z  # 現在の累積和
    while True:
        # c, d をそれぞれ 正の整数 m で割った剰余にする正規化
        # Python の divmod は Flooring Division に基づくので、除数 m が正であるため
        # 元の c, d が負でも正規化後の剰余は 0 <= c < m, 0 <= d < m が保証される
        # 負の整数 % 正の整数 = 負の整数 となる言語(C++/Java など)では移植時に注意
        q, c = divmod(c, m)  # q = c // m, c = c % m
        a += b * q  # c の商分を a に足す
        q, d = divmod(d, m)  # q = d // m, d = d % m
        sum_acc += b * q  # d の商分を s に足す
        assert 0 <= c < m and 0 <= d < m
        # 左端が z を超える場合は早期終了
        if sum_acc > 0:
            return False
        # 0 ≤ x < n における y = floor((c*x+d)/m) の最大値を計算
        y_max = (c * (n - 1) + d) // m
        # y_max == 0 の場合は右端が z を超えるか判定して終了
        if y_max == 0:
            return (sum_acc + a * (n - 1)) <= 0
        # どうしても z 以下な場合は早期終了
        if sum_acc + max(0, a * (n - 1)) + max(0, b * y_max) <= 0:
            return True
        # y_max >= 1 の場合は再帰的に解く
        # c > 0, n > 1 のときにのみ y_max >= 1 となりうる
        if a >= 0:
            # a >= 0 の場合 : 右端が z を超える場合は早期終了
            if (sum_acc + a * (n - 1) + b * y_max) > 0:
                return False
        else:
            # a < 0 の場合
            sum_acc += a + b
        # 小問題へのパラメータ変換
        n, m, a, b, c, d = y_max, c, b, a, m, (m - d - 1)


def mwf_lr(L: int, R: int, m: int, a: int, b: int, c: int, d: int) -> int:
    """
    max_{L <= x < R} a*x + b*floor((c*x + d)/m) を計算して返す。

    既存の mwf(n, m, a, b, c, d)(0 <= x < n)を用いる。
    前提: L < R, m > 0
    計算量: 既存の mwf に準ずる(O(log m) スタイルの再帰)。
    """
    assert L < R and m > 0
    n = R - L
    q, d = divmod(c * L + d, m)
    return a * L + b * q + mwf(n, m, a, b, c, d)


def mwf_lr_leq(z: int, L: int, R: int, m: int, a: int, b: int, c: int, d: int) -> bool:
    """
    max_{L <= x < R} a*x + b*floor((c*x + d)/m) <= z なら true、そうでなければ false を返す。

    既存の mwf_ge(n, m, a, b, c, d)(0 <= x < n)を用いる。
    前提: L < R, m > 0
    計算量: 既存の mwf に準ずる(O(log m) スタイルの再帰)。
    """
    assert L < R and m > 0
    n = R - L
    q, d = divmod(c * L + d, m)
    return mwf_leq(z - a * L - b * q, n, m, a, b, c, d)


def compute_xmin_leq(D: int, A: int, B: int, K: int) -> int:
    """
    x_min(D, A, B, K) を半開区間二分探索 [0, A'BK+2) で求めます(解なしは -1)。

    前提:
      * D > 0, A > 0, B > 0, K >= 0(整数)
    手順概要:
      1) 既約化: g = gcd(D, A), D' = D/g, A' = A/g
      2) (M', R') = divmod(A' * B, D')(A'B = D'*M' + R')
      3) 閾値 T_Δ = B*K を設定
      4) E(u) = B*u - M'*floor(D'u / A')
      5) F(N) = max_{0 <= u < N} E(u) を mwf で評価(N > 0, m = A' > 0)
      6) 区間 [0, A'BK+2) で述語 [F(u) <= T_Δ] を二分探索し、
         F(u) <= T_Δ となる最大の u 、つまり T_Δ < E(u) となる最小の u を特定。x = D*u を返す。

    備考:
      * R' = 0 かつ D'K + 1 >= A' のときは解が存在しないため -1 を返します。
      * 解が存在する場合、 u_min は必ず [0, A'BK+2) の範囲に存在します。
    """
    import math 
    assert D > 0 and A > 0 and B > 0 and K >= 0
    gcd_DA = math.gcd(D, A)
    Dred, Ared = D // gcd_DA, A // gcd_DA
    Mred, Rred = divmod(Ared * B, Dred)
    Tdelta = B * K
    # 解なしをパラメータを用いて判定
    if Rred == 0 and Dred * K + 1 >= Ared:
        return -1
    # [0, hi) の半開区間、緩い上界 A'BK+1 を包括する hi = A'BK+2 を設定
    lo, hi = 0, Ared * B * K + 2
    # F(hi) > T の不変条件を確認
    assert not mwf_lr_leq(Tdelta, lo, hi, Ared, B, -Mred, Dred, 0)
    # F(lo) <= T, F(hi) > T の不変条件で u_min を二分探索
    while lo + 1 < hi:
        mid = (lo + hi) // 2
        if mwf_lr_leq(Tdelta, lo, mid, Ared, B, -Mred, Dred, 0):
            lo = mid
        else:
            hi = mid
    # lo = u_min, hi = lo + 1
    return D * lo


def delta_val(D: int, A: int, B: int, x: int) -> int:
    """検算用 Δ(D,A,B,x)。"""
    P = x // D
    M = (A * B) // D
    Q = ((x // A) * M) // B
    return P - Q


def solve():
    """
    入力を受け取り、各ケースについて x_min(D, A, B, K) を求めて出力します。
    """
    import sys
    input = sys.stdin.readline

    T = int(input())
    for _ in range(T):
        D, A, B, K = map(int, input().split())
        assert 1 <= D
        assert 1 <= A
        assert 1 <= B
        assert 0 <= K
        ans = compute_xmin_leq(D, A, B, K)
        print(ans)


if __name__ == '__main__':
    solve()
0