結果

問題 No.3398 Accuracy of Integer Division Approximate Function 2
コンテスト
ユーザー NyaanNyaan
提出日時 2025-12-05 01:11:17
言語 Python3
(3.13.1 + numpy 2.2.1 + scipy 1.14.1)
結果
TLE  
実行時間 -
コード長 5,873 bytes
記録
記録タグの例:
初AC ショートコード 純ショートコード 純主流ショートコード 最速実行時間
コンパイル時間 1,505 ms
コンパイル使用メモリ 12,672 KB
実行使用メモリ 18,848 KB
最終ジャッジ日時 2025-12-05 01:11:24
合計ジャッジ時間 5,486 ms
ジャッジサーバーID
(参考情報)
judge3 / judge1
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 3
other TLE * 1 -- * 19
権限があれば一括ダウンロードができます

ソースコード

diff #
raw source code

"""
#include "template/template.hpp"
//
#include "math/bigint-all.hpp"
using namespace Nyaan;

bigint floor_sum(bigint n, bigint m, bigint a, bigint b) {
  bigint ans = 0;
  if (a < 0) {
    bigint a2 = a % m + (a % m < 0 ? m : 0);
    ans -= 1ULL * n * (n - 1) / 2 * ((a2 - a) / m);
    a = a2;
  }
  if (b < 0) {
    bigint b2 = b % m + (b % m < 0 ? m : 0);
    ans -= 1ULL * n * ((b2 - b) / m);
    b = b2;
  }
  if (a >= m) {
    ans += (n - 1) * n * (a / m) / 2;
    a %= m;
  }
  if (b >= m) {
    ans += n * (b / m);
    b %= m;
  }
  bigint y_max = (a * n + b) / m, x_max = (y_max * m - b);
  if (y_max == 0) return ans;
  ans += (n - (x_max + a - 1) / a) * y_max;
  ans += floor_sum(y_max, a, m, (a - x_max % a) % a);
  return ans;
}

bigint fs(bigint nL, bigint nR, bigint m, bigint a, bigint b) {
  return floor_sum(nR, m, a, b) - floor_sum(nL, m, a, b);
}

bigint between(bigint N, bigint a1, bigint b1, bigint m1, bigint a2, bigint b2,
               bigint m2) {
  auto on = [&](bigint x) { return (a1 * x + b1) * m2 > (a2 * x + b2) * m1; };
  bigint nL = 0, nR = N;
  if (!on(nL)) {
    bigint ng = nL, ok = nR;
    while (ng + 1 < ok) {
      bigint x = (ng + ok) / 2;
      (on(x) ? ok : ng) = x;
    }
    nL = ok;
  }
  if (!on(nR)) {
    bigint ok = nL, ng = nR;
    while (ok + 1 < ng) {
      bigint x = (ok + ng) / 2;
      (on(x) ? ok : ng) = x;
    }
    nR = ng;
  }
  return fs(nL, nR, m1, a1, b1) - fs(nL, nR, m2, a2, b2);
}

bigint calc(bigint D, bigint A, bigint B, bigint K) {
  K += 1;
  bigint C = A * B / D;
  if (C == 0) return K * D;

  struct Line {
    bigint a, b, m;
  };
  Line L1{D, 0, A};
  Line L2{D, -A, A};
  Line L3{B, B * (-K + 1) - 1, C};
  Line L4{B, -B * K - 1, C};

  bigint INF = Power(bigint(10), 50);
  bigint ng = 0, ok = INF;
  while (ng + 1 < ok) {
    bigint N = (ng + ok) / 2;
    bigint n1 = between(N + 1, L1.a, L1.b, L1.m, L4.a, L4.b, L4.m);
    bigint n2 = between(N + 1, L1.a, L1.b, L1.m, L3.a, L3.b, L3.m);
    bigint n3 = between(N + 1, L2.a, L2.b, L2.m, L4.a, L4.b, L4.m);
    bigint n4 = between(N + 1, L2.a, L2.b, L2.m, L3.a, L3.b, L3.m);
    bigint num = n1 - n2 - n3 + n4;
    (num > 0 ? ok : ng) = N;
  }
  return ok == INF ? -1 : ok * D;
}

void q() {
  bigint D, A, B, K;
  in(D, A, B, K);
  out(calc(D, A, B, K));
}

void Nyaan::solve() {
  int t = 1;
  in(t);
  while (t--) q();
}
"""

# translated by ChatGPT 5.1

import sys


def floor_sum(n: int, m: int, a: int, b: int) -> int:
    """
    sum_{i=0}^{n-1} floor((a*i + b) / m)
    C++ 版をそのまま Python 向けにポートしたもの。
    """
    ans = 0

    # a, b の負を処理(Python の % は常に 0..m-1 なので少しだけ簡略化)
    if a < 0:
        a2 = a % m
        ans -= n * (n - 1) // 2 * ((a2 - a) // m)
        a = a2
    if b < 0:
        b2 = b % m
        ans -= n * ((b2 - b) // m)
        b = b2

    if a >= m:
        ans += (n - 1) * n * (a // m) // 2
        a %= m
    if b >= m:
        ans += n * (b // m)
        b %= m

    y_max = (a * n + b) // m
    x_max = y_max * m - b
    if y_max == 0:
        return ans

    # ここから再帰
    ans += (n - (x_max + a - 1) // a) * y_max
    ans += floor_sum(y_max, a, m, (a - x_max % a) % a)
    return ans


def fs(nL: int, nR: int, m: int, a: int, b: int) -> int:
    """
    sum_{i=nL}^{nR-1} floor((a*i + b) / m)
    """
    return floor_sum(nR, m, a, b) - floor_sum(nL, m, a, b)


def between(N: int,
            a1: int, b1: int, m1: int,
            a2: int, b2: int, m2: int) -> int:
    """
    C++ の between をそのまま移植。
    [0, N] の範囲で on(x) が true になる区間を二分探索し、
    そこでの fs(...) の差を返す。
    """

    def on(x: int) -> bool:
        # (a1*x + b1)/m1 > (a2*x + b2)/m2  を交差乗算で比較
        return (a1 * x + b1) * m2 > (a2 * x + b2) * m1

    nL, nR = 0, N

    # 左端側の「最初に true になる点」を探す
    if not on(nL):
        ng, ok = nL, nR
        while ng + 1 < ok:
            x = (ng + ok) // 2
            if on(x):
                ok = x
            else:
                ng = x
        nL = ok

    # 右端側の「最後に true になる点」を探す
    if not on(nR):
        ok, ng = nL, nR
        while ok + 1 < ng:
            x = (ok + ng) // 2
            if on(x):
                ok = x
            else:
                ng = x
        nR = ng

    return fs(nL, nR, m1, a1, b1) - fs(nL, nR, m2, a2, b2)


def calc(D: int, A: int, B: int, K: int) -> int:
    K += 1
    C = A * B // D
    if C == 0:
        return K * D

    # C++ の struct Line { bigint a, b, m; };
    class Line:
        __slots__ = ("a", "b", "m")

        def __init__(self, a: int, b: int, m: int) -> None:
            self.a = a
            self.b = b
            self.m = m

    L1 = Line(D, 0, A)
    L2 = Line(D, -A, A)
    L3 = Line(B, B * (-K + 1) - 1, C)
    L4 = Line(B, -B * K - 1, C)

    INF = 10 ** 50
    ng, ok = 0, INF

    # num > 0 となる最小の N を二分探索
    while ng + 1 < ok:
        N = (ng + ok) // 2
        n1 = between(N + 1, L1.a, L1.b, L1.m, L4.a, L4.b, L4.m)
        n2 = between(N + 1, L1.a, L1.b, L1.m, L3.a, L3.b, L3.m)
        n3 = between(N + 1, L2.a, L2.b, L2.m, L4.a, L4.b, L4.m)
        n4 = between(N + 1, L2.a, L2.b, L2.m, L3.a, L3.b, L3.m)
        num = n1 - n2 - n3 + n4

        if num > 0:
            ok = N
        else:
            ng = N

    return -1 if ok == INF else ok * D


def main() -> None:
    it = iter(map(int, sys.stdin.read().split()))
    t = next(it, 1)  # 1 行目に t がなければ t=1 とみなす
    out_lines = []
    for _ in range(t):
        D = next(it)
        A = next(it)
        B = next(it)
        K = next(it)
        out_lines.append(str(calc(D, A, B, K)))
    sys.stdout.write("\n".join(out_lines))


if __name__ == "__main__":
    main()
0