結果

問題 No.2438 Double Least Square
ユーザー gew1fw
提出日時 2025-06-12 18:05:24
言語 PyPy3
(7.3.15)
結果
WA  
実行時間 -
コード長 1,927 bytes
コンパイル時間 192 ms
コンパイル使用メモリ 82,080 KB
実行使用メモリ 69,920 KB
最終ジャッジ日時 2025-06-12 18:07:11
合計ジャッジ時間 2,483 ms
ジャッジサーバーID
(参考情報)
judge1 / judge5
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 3
other AC * 16 WA * 14
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys

def main():
    input = sys.stdin.read().split()
    idx = 0
    N = int(input[idx])
    idx += 1
    H = int(input[idx])
    idx += 1
    x = []
    y = []
    for _ in range(N):
        xi = int(input[idx])
        yi = int(input[idx + 1])
        x.append(xi)
        y.append(yi)
        idx += 2

    a1 = 0.0
    a2 = 0.0
    eps = 1e-12
    max_iter = 100
    for _ in range(max_iter):
        S = []
        T = []
        for i in range(N):
            xi_val = x[i]
            yi_val = y[i]
            f = a1 * xi_val + H
            g = a2 * xi_val
            error_f = (yi_val - f) ** 2
            error_g = (yi_val - g) ** 2
            if error_f < error_g - eps:
                S.append(i)
            else:
                T.append(i)
        # Compute new_a1
        sum_xi_yi_minus_H = 0.0
        sum_xi_sq_S = 0.0
        for i in S:
            xi_val = x[i]
            yi_val = y[i]
            sum_xi_yi_minus_H += xi_val * (yi_val - H)
            sum_xi_sq_S += xi_val * xi_val
        new_a1 = a1
        if sum_xi_sq_S != 0:
            new_a1 = sum_xi_yi_minus_H / sum_xi_sq_S

        # Compute new_a2
        sum_xi_yi = 0.0
        sum_xi_sq_T = 0.0
        for i in T:
            xi_val = x[i]
            yi_val = y[i]
            sum_xi_yi += xi_val * yi_val
            sum_xi_sq_T += xi_val * xi_val
        new_a2 = a2
        if sum_xi_sq_T != 0:
            new_a2 = sum_xi_yi / sum_xi_sq_T

        # Check convergence
        if abs(new_a1 - a1) < eps and abs(new_a2 - a2) < eps:
            break
        a1 = new_a1
        a2 = new_a2

    # Compute the sum of minima
    total = 0.0
    for i in range(N):
        xi_val = x[i]
        yi_val = y[i]
        error_f = (yi_val - (a1 * xi_val + H)) ** 2
        error_g = (yi_val - (a2 * xi_val)) ** 2
        total += min(error_f, error_g)
    print("{0:.20f}".format(total))

if __name__ == '__main__':
    main()
0