結果

問題 No.913 木の燃やし方
ユーザー gew1fw
提出日時 2025-06-12 14:00:06
言語 PyPy3
(7.3.15)
結果
WA  
実行時間 -
コード長 2,237 bytes
コンパイル時間 134 ms
コンパイル使用メモリ 82,676 KB
実行使用メモリ 121,508 KB
最終ジャッジ日時 2025-06-12 14:01:03
合計ジャッジ時間 13,179 ms
ジャッジサーバーID
(参考情報)
judge3 / judge1
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 3
other AC * 13 WA * 21
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys

def main():
    sys.setrecursionlimit(1 << 25)
    N, *rest = map(int, sys.stdin.read().split())
    A = rest[:N]
    prefix = [0] * (N + 1)
    for i in range(N):
        prefix[i+1] = prefix[i] + A[i]
    
    # Precompute left optimal for each i (i is right end)
    left_opt = list(range(N))
    for i in range(N):
        # Ternary search on j in [0, i]
        low = 0
        high = i
        best_j = i
        best_val = A[i] + 1
        while high - low > 3:
            m1 = low + (high - low) // 3
            m2 = high - (high - low) // 3
            k1 = i - m1 + 1
            sum1 = prefix[i+1] - prefix[m1]
            val1 = sum1 + k1 * k1
            k2 = i - m2 + 1
            sum2 = prefix[i+1] - prefix[m2]
            val2 = sum2 + k2 * k2
            if val1 < val2:
                high = m2
            else:
                low = m1
        for j in range(low, high+1):
            k = i - j + 1
            current = (prefix[i+1] - prefix[j]) + k * k
            if current < best_val:
                best_val = current
                best_j = j
        left_opt[i] = best_j
    
    # Precompute right optimal for each i (i is left end)
    right_opt = list(range(N))
    for i in range(N):
        # Ternary search on j in [i, N-1]
        low = i
        high = N-1
        best_j = i
        best_val = A[i] + 1
        while high - low > 3:
            m1 = low + (high - low) // 3
            m2 = high - (high - low) // 3
            k1 = m1 - i + 1
            sum1 = prefix[m1+1] - prefix[i]
            val1 = sum1 + k1 * k1
            k2 = m2 - i + 1
            sum2 = prefix[m2+1] - prefix[i]
            val2 = sum2 + k2 * k2
            if val1 < val2:
                high = m2
            else:
                low = m1
        for j in range(low, high+1):
            k = j - i + 1
            current = (prefix[j+1] - prefix[i]) + k * k
            if current < best_val:
                best_val = current
                best_j = j
        right_opt[i] = best_j
    
    for i in range(N):
        l = left_opt[i]
        r = right_opt[i]
        total = prefix[r+1] - prefix[l]
        k = r - l + 1
        print(total + k * k)
    
if __name__ == '__main__':
    main()
0