結果

問題 No.2438 Double Least Square
ユーザー gew1fw
提出日時 2025-06-12 18:05:06
言語 PyPy3
(7.3.15)
結果
TLE  
実行時間 -
コード長 2,888 bytes
コンパイル時間 142 ms
コンパイル使用メモリ 82,020 KB
実行使用メモリ 185,872 KB
最終ジャッジ日時 2025-06-12 18:06:19
合計ジャッジ時間 4,032 ms
ジャッジサーバーID
(参考情報)
judge5 / judge1
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample -- * 3
other AC * 10 TLE * 1 -- * 19
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys

def readints():
    """Return the whitespace-separated integers on the next stdin line as a list."""
    line = sys.stdin.readline()
    return [int(token) for token in line.split()]

def compute_min_sum(a1_init, a2_init, points, H, max_iter=100, eps=1e-9):
    """Refine the slope pair (a1, a2) by alternating assignment and refit.

    The two model lines are f(x) = a1*x + H and g(x) = a2*x. Each round does two
    things:

    1. Assign every point to the line with the smaller squared vertical residual.
       A point whose residuals are within ``eps`` of each other goes to both lines.
    2. Refit each slope by least squares on its assigned points. A slope whose
       point set has sum(x*x) == 0 is left unchanged.

    This is a local search, so the result depends on the starting slopes.

    Args:
        a1_init: Starting slope of f.
        a2_init: Starting slope of g.
        points: Sequence of (x, y) pairs.
        H: Fixed intercept of f.
        max_iter: Maximum number of assignment/refit rounds.
        eps: Tolerance used both for residual ties and for slope convergence.

    Returns:
        sum over points of min((y - f(x))**2, (y - g(x))**2) at the final slopes.
    """
    a1 = a1_init
    a2 = a2_init
    prev_mask_f = None
    prev_mask_g = None
    for _ in range(max_iter):
        mask_f = []
        mask_g = []
        # Least-squares accumulators, gathered in the same pass as the assignment
        # instead of five extra passes over the assigned points.
        sxx_f = sx_f = sxy_f = 0
        sxx_g = sxy_g = 0
        for x, y in points:
            err_f = (y - (a1 * x + H)) ** 2
            err_g = (y - a2 * x) ** 2
            if err_f < err_g - eps:
                use_f, use_g = True, False
            elif err_g < err_f - eps:
                use_f, use_g = False, True
            else:
                # Near-tie: the point contributes to both fits.
                use_f = use_g = True
            mask_f.append(use_f)
            mask_g.append(use_g)
            if use_f:
                sxx_f += x * x
                sx_f += x
                sxy_f += x * y
            if use_g:
                sxx_g += x * x
                sxy_g += x * y
        # Converged only when BOTH assignments are unchanged. Comparing f's new
        # set against g's old set, as before, is unrelated and can stop early.
        if mask_f == prev_mask_f and mask_g == prev_mask_g:
            break
        prev_mask_f = mask_f
        prev_mask_g = mask_g
        # argmin over a1 of sum (y - a1*x - H)^2  ->  (Sxy - H*Sx) / Sxx
        new_a1 = (sxy_f - H * sx_f) / sxx_f if sxx_f != 0 else a1
        # argmin over a2 of sum (y - a2*x)^2      ->  Sxy / Sxx
        new_a2 = sxy_g / sxx_g if sxx_g != 0 else a2
        if abs(new_a1 - a1) < eps and abs(new_a2 - a2) < eps:
            break
        a1, a2 = new_a1, new_a2
    total = 0.0
    for x, y in points:
        error_f = (y - (a1 * x + H)) ** 2
        error_g = (y - (a2 * x)) ** 2
        total += min(error_f, error_g)
    return total

def main():
    """Read the instance from stdin and print the best cost over all starting slopes.

    Input (whitespace-separated tokens): N, H, then N pairs ``x y``. Each
    distinct starting pair (a1, a2) is refined with compute_min_sum, and the
    smallest resulting cost is printed with 20 decimal places.
    """
    # Read all tokens at once; much faster than per-line readline for large N.
    data = sys.stdin.buffer.read().split()
    N = int(data[0])
    H = int(data[1])
    coords = list(map(int, data[2:2 + 2 * N]))
    points = list(zip(coords[0::2], coords[1::2]))

    sum_x = sum(x for x, _ in points)
    sum_xx = sum(x * x for x, _ in points)
    # A single sum serves both fits (the original computed it twice).
    sum_xy = sum(x * y for x, y in points)
    if sum_xx != 0:
        # Least-squares slopes if every point were fitted by the same line.
        candidates = [((sum_xy - H * sum_x) / sum_xx, sum_xy / sum_xx)]
    else:
        candidates = [(0.0, 0.0)]

    # Per-point heuristic starts (g(x) = H and g(x) = 2y - H at that point).
    # NOTE(review): the rationale for these particular starts is not documented.
    for x, y in points:
        if x != 0:
            candidates.append((0.0, H / x))
            candidates.append((0.0, (2 * y - H) / x))

    candidates.extend([(0.0, 0.0), (0.0, 1.0), (1.0, 0.0), (-100.0, 100.0)])

    # compute_min_sum is deterministic, so identical starts give identical costs.
    # Evaluate each distinct start once, keeping first-seen order.
    unique_starts = dict.fromkeys(candidates)
    min_total = float('inf')
    for a1, a2 in unique_starts:
        current_total = compute_min_sum(a1, a2, points, H)
        if current_total < min_total:
            min_total = current_total

    print("{0:.20f}".format(min_total))

if __name__ == "__main__":
    main()
0