結果

問題 No.2438 Double Least Square
ユーザー gew1fw
提出日時 2025-06-12 14:34:40
言語 PyPy3
(7.3.15)
結果
WA  
実行時間 -
コード長 1,971 bytes
コンパイル時間 138 ms
コンパイル使用メモリ 82,204 KB
実行使用メモリ 76,596 KB
最終ジャッジ日時 2025-06-12 14:35:02
合計ジャッジ時間 2,800 ms
ジャッジサーバーID
(参考情報)
judge5 / judge2
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 3
other AC * 16 WA * 14
権限があれば一括ダウンロードができます

ソースコード

diff #

def main():
    import sys
    input = sys.stdin.read().split()
    idx = 0
    N = int(input[idx])
    idx += 1
    H = int(input[idx])
    idx += 1
    points = []
    for _ in range(N):
        x = int(input[idx])
        y = int(input[idx + 1])
        points.append((x, y))
        idx += 2

    a1 = 0.0
    a2 = 0.0
    tolerance = 1e-8
    max_iterations = 10000
    converged = False

    for _ in range(max_iterations):
        # Assign points to S_f and S_g
        S_f = []
        S_g = []
        for x, y in points:
            error_f = (y - (a1 * x + H)) ** 2
            error_g = (y - (a2 * x)) ** 2
            if error_f < error_g:
                S_f.append((x, y))
            else:
                S_g.append((x, y))

        # Compute a1_new
        sum_x2_f = sum(x * x for x, y in S_f)
        sum_xy_f = sum(x * y for x, y in S_f)
        sum_x_f = sum(x for x, y in S_f)
        if sum_x2_f == 0:
            a1_new = a1
        else:
            numerator = sum_xy_f - H * sum_x_f
            a1_new = numerator / sum_x2_f

        # Compute a2_new
        sum_x2_g = sum(x * x for x, y in S_g)
        sum_xy_g = sum(x * y for x, y in S_g)
        if sum_x2_g == 0:
            a2_new = a2
        else:
            a2_new = sum_xy_g / sum_x2_g

        # Check for convergence
        change_a1 = abs(a1_new - a1)
        change_a2 = abs(a2_new - a2)
        rel_change_a1 = change_a1 / (abs(a1) + 1e-8)
        rel_change_a2 = change_a2 / (abs(a2) + 1e-8)

        if rel_change_a1 <= tolerance and rel_change_a2 <= tolerance:
            a1 = a1_new
            a2 = a2_new
            converged = True
            break

        a1 = a1_new
        a2 = a2_new

    if not converged:
        pass

    # Compute L
    L = 0.0
    for x, y in points:
        error_f = (y - (a1 * x + H)) ** 2
        error_g = (y - (a2 * x)) ** 2
        L += min(error_f, error_g)

    print("{0:.15f}".format(L))

if __name__ == "__main__":
    main()
0