結果

問題 No.2438 Double Least Square
ユーザー lam6er
提出日時 2025-04-16 15:59:35
言語 PyPy3 (7.3.15)
結果 WA
実行時間 -
コード長 3,101 bytes
コンパイル時間 233 ms
コンパイル使用メモリ 81,532 KB
実行使用メモリ 76,780 KB
最終ジャッジ日時 2025-04-16 16:01:26
合計ジャッジ時間 3,135 ms
ジャッジサーバーID (参考情報) judge4 / judge3
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 3
other AC * 20 WA * 10
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys

def readints():
    """Read one line from standard input and return its integers as a list.

    The line is split on whitespace; a blank line (or end of input) yields
    an empty list.
    """
    tokens = sys.stdin.readline().split()
    return [int(token) for token in tokens]

def compute_min_L(points, H, initial_a1, initial_a2, max_iter=200):
    """Refine two line slopes by alternating assignment and refit.

    Two lines are considered: ``y = a1 * x + H`` and ``y = a2 * x``.  Each
    round labels every point with the line whose squared vertical error is
    strictly smaller (ties go to the second line), then refits each slope
    by least squares over its own points.  A line that receives no usable
    points (sum of ``x ** 2`` equal to zero) gets slope ``0.0``.

    Iteration stops when the labelling repeats, when both slopes move by
    less than ``1e-12`` (in which case the previous slopes are kept), or
    after ``max_iter`` rounds.

    NOTE(review): this is an iterative refinement from the given starting
    slopes; nothing in the code establishes that it reaches a global
    minimum.

    Args:
        points: sequence of ``(x, y)`` pairs.
        H: constant offset of the first line.
        initial_a1: starting slope of the line ``y = a1 * x + H``.
        initial_a2: starting slope of the line ``y = a2 * x``.
        max_iter: upper bound on the number of refinement rounds.

    Returns:
        The sum over all points of the smaller of the two squared errors,
        evaluated with the final slopes, as a float.
    """
    slope_f, slope_g = initial_a1, initial_a2
    last_labels = None

    for _ in range(max_iter):
        # True means the offset line (slope_f) fits the point strictly better.
        labels = [
            (y - (slope_f * x + H)) ** 2 < (y - (slope_g * x)) ** 2
            for x, y in points
        ]
        if labels == last_labels:
            break
        last_labels = labels

        # Accumulate the least-squares sums for each group separately.
        f_xx = 0.0
        f_xy = 0.0
        g_xx = 0.0
        g_xy = 0.0
        for (x, y), closer_to_f in zip(points, labels):
            if closer_to_f:
                f_xx += x ** 2
                f_xy += x * (y - H)
            else:
                g_xx += x ** 2
                g_xy += x * y

        if f_xx != 0:
            next_f = f_xy / f_xx
        else:
            next_f = 0.0
        if g_xx != 0:
            next_g = g_xy / g_xx
        else:
            next_g = 0.0

        # On convergence the loop exits before adopting the new slopes.
        f_settled = abs(next_f - slope_f) < 1e-12
        g_settled = abs(next_g - slope_g) < 1e-12
        if f_settled and g_settled:
            break
        slope_f, slope_g = next_f, next_g

    total = 0.0
    for x, y in points:
        residual_f = (y - (slope_f * x + H)) ** 2
        residual_g = (y - (slope_g * x)) ** 2
        total += min(residual_f, residual_g)
    return total

def main():
    """Read the input, run the slope refinement from several starts, print the best cost.

    Input (whitespace-separated tokens on standard input): ``N``, ``H``,
    then ``N`` pairs ``x y``, all integers.

    The starting slope pairs are built from the least-squares slope of each
    line fitted to *all* points (``a1_all`` for ``y = a1 * x + H``,
    ``a2_all`` for ``y = a2 * x``) combined with zero.  The smallest cost
    returned by ``compute_min_L`` over those starts is printed with up to 20
    decimals, trailing zeros and a trailing dot removed.
    """
    import sys

    tokens = sys.stdin.read().split()
    N = int(tokens[0])
    H = int(tokens[1])
    points = [
        (int(tokens[2 + 2 * i]), int(tokens[3 + 2 * i]))
        for i in range(N)
    ]

    # Both all-points fits share the same denominator, so compute it once.
    sum_xx = sum(x * x for x, _ in points)
    if sum_xx != 0:
        a1_all = sum(x * (y - H) for x, y in points) / sum_xx
        a2_all = sum(x * y for x, y in points) / sum_xx
    else:
        a1_all = 0.0
        a2_all = 0.0

    # The original list ended with a second copy of (a1_all, a2_all); the
    # refinement is deterministic, so that repeat could never change the
    # minimum and is dropped here.
    initializations = [
        (a1_all, a2_all),
        (0.0, 0.0),
        (a1_all, 0.0),
        (0.0, a2_all),
    ]

    min_L = float('inf')
    for init_a1, init_a2 in initializations:
        L = compute_min_L(points, H, init_a1, init_a2)
        if L < min_L:
            min_L = L

    text = "{0:.20f}".format(min_L)
    if '.' in text:
        # Strip padding zeros, then a dangling decimal point.
        text = text.rstrip('0').rstrip('.')
    print(text)

# Run the solver only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
0