結果

問題 No.2438 Double Least Square
ユーザー lam6er
提出日時 2025-03-31 17:54:41
言語 PyPy3
(7.3.15)
結果
WA  
実行時間 -
コード長 2,162 bytes
コンパイル時間 347 ms
コンパイル使用メモリ 82,448 KB
実行使用メモリ 76,736 KB
最終ジャッジ日時 2025-03-31 17:56:02
合計ジャッジ時間 3,199 ms
ジャッジサーバーID
(参考情報)
judge4 / judge3
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 3
other AC * 21 WA * 9
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys

def compute_L(points, H, a1_init, a2_init):
    a1 = a1_init
    a2 = a2_init
    for _ in range(100):
        sum_f_num = 0.0
        sum_f_den = 0.0
        sum_g_num = 0.0
        sum_g_den = 0.0
        assigned_f = []
        assigned_g = []
        for x, y in points:
            res_f = (y - (a1 * x + H))**2
            res_g = (y - a2 * x)**2
            if res_f <= res_g:
                assigned_f.append((x, y))
                sum_f_num += (y - H) * x
                sum_f_den += x * x
            else:
                assigned_g.append((x, y))
                sum_g_num += y * x
                sum_g_den += x * x
        new_a1, new_a2 = a1, a2
        if assigned_f:
            new_a1 = sum_f_num / sum_f_den if sum_f_den != 0 else a1
        if assigned_g:
            new_a2 = sum_g_num / sum_g_den if sum_g_den != 0 else a2
        if abs(new_a1 - a1) < 1e-12 and abs(new_a2 - a2) < 1e-12:
            break
        a1, a2 = new_a1, new_a2
    total = 0.0
    for x, y in points:
        res_f = (y - (a1 * x + H))**2
        res_g = (y - a2 * x)**2
        total += min(res_f, res_g)
    return total

def main():
    input = sys.stdin.read().split()
    ptr = 0
    N = int(input[ptr])
    ptr += 1
    H = int(input[ptr])
    ptr += 1
    points = []
    sum_x2_all = 0.0
    sum_f_num_initial = 0.0
    sum_g_num_initial = 0.0
    for _ in range(N):
        x = int(input[ptr])
        y = int(input[ptr + 1])
        ptr += 2
        points.append((x, y))
        sum_x2_all += x * x
        sum_f_num_initial += (y - H) * x
        sum_g_num_initial += y * x

    sum_x2_all = max(sum_x2_all, 1e-9)  # Prevent division by zero

    a1_init_all_f = sum_f_num_initial / sum_x2_all
    a2_init_all_g = sum_g_num_initial / sum_x2_all

    initial_params = [
        (0.0, 0.0),
        (a1_init_all_f, a2_init_all_g),
        (a1_init_all_f, 0.0),
        (0.0, a2_init_all_g),
    ]

    min_L = float('inf')
    for a1_init, a2_init in initial_params:
        L = compute_L(points, H, a1_init, a2_init)
        if L < min_L:
            min_L = L

    print("{0:.12f}".format(min_L))

if __name__ == "__main__":
    main()
0