結果

問題 No.1936 Rational Approximation
ユーザー ir5ir5
提出日時 2024-06-16 01:19:05
言語 PyPy3
(7.3.15)
結果
AC  
実行時間 1,003 ms / 2,000 ms
コード長 1,793 bytes
コンパイル時間 391 ms
コンパイル使用メモリ 82,392 KB
実行使用メモリ 90,816 KB
最終ジャッジ日時 2024-06-16 01:19:22
合計ジャッジ時間 15,504 ms
ジャッジサーバーID
(参考情報)
judge2 / judge1
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 716 ms
90,284 KB
testcase_01 AC 764 ms
90,136 KB
testcase_02 AC 747 ms
90,816 KB
testcase_03 AC 1,003 ms
89,896 KB
testcase_04 AC 930 ms
90,096 KB
testcase_05 AC 972 ms
90,360 KB
testcase_06 AC 953 ms
90,036 KB
testcase_07 AC 955 ms
90,216 KB
testcase_08 AC 942 ms
89,976 KB
testcase_09 AC 970 ms
90,240 KB
testcase_10 AC 952 ms
90,092 KB
testcase_11 AC 978 ms
90,044 KB
testcase_12 AC 912 ms
90,188 KB
testcase_13 AC 971 ms
90,232 KB
testcase_14 AC 142 ms
88,680 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

from fractions import Fraction


def xgcd(a, b):
    x0, y0, x1, y1 = 1, 0, 0, 1
    while b != 0:
        q, a, b = a // b, b, a % b
        x0, x1 = x1, x0 - q * x1
        y0, y1 = y1, y0 - q * y1
    return a, x0, y0


def modinv(a, m):
    g, x, y = xgcd(a, m)
    if g != 1:
        raise Exception('Modular inverse does not exist')
    else:
        return x % m


def solve1(p, q, sqrt):
    invp = modinv(p, q)

    # lower
    lower = Fraction(0, 1)

    def go1(x):
        nonlocal lower
        f = Fraction(int(p * x // q), x)
        lower = max(lower, f)

    for x in range(1, sqrt):
        # go1(x)
        go1(q - x)

    for y in range(1, sqrt):
        # px mod q = y
        go1(y * invp % q)

    return lower


def solve2(p, q, sqrt):
    invp = modinv(p, q)

    # upper
    upper = Fraction(p, 1)

    def go2(x):
        nonlocal upper
        f = Fraction(int(p * x // q + 1), x)
        upper = min(upper, f)

    for x in range(1, sqrt):
        go2(x)
        # go2(q - x)

    for y in range(q - sqrt + 1, q):
        # px mod q = y
        go2(y * invp % q)

    return upper


def main():
    p, q = list(map(int, input().split()))
    sqrt = min(q, 5 * 10 ** 5)

    lower = solve1(p, q, sqrt)
    upper = solve2(p, q, sqrt)

    import sys
    print(upper, file=sys.stderr)
    print(lower, file=sys.stderr)

    print(upper.numerator + upper.denominator + lower.numerator + lower.denominator)


def stress():
    q = 100000
    import numpy as np
    for _ in range(100):
        p = np.random.randint(1, q)
        import math
        if math.gcd(p, q) != 1:
            continue

        lower_b = solve1(p, q, q)
        lower = solve1(p, q, int(np.sqrt(q)) + 10)

        print(p, q, lower_b, lower)
        assert lower_b == lower


# stress()
main()
0