"""Minimum cost to split a list of integers into small contiguous groups.

Reads ``n`` from the first line of stdin and the values from the second
line, sorts the values, and partitions the first ``n`` sorted values into
contiguous runs of size 2 or 3. Each run costs the sum of absolute
deviations of its members from the run's median; the script prints the
minimum achievable total cost.
"""

import sys


def _median_cost(ordered, prefix, left, right):
    """Return sum(|ordered[k] - median|) for left <= k <= right.

    ``ordered`` must be sorted ascending and ``prefix[k]`` must equal
    ``sum(ordered[:k])``. For an even-length run the lower median is used;
    any point between the two middle values gives the same total, so the
    choice does not affect the result.
    """
    mid = left + (right - left) // 2
    pivot = ordered[mid]
    # Members at or below the median contribute (pivot - value) each.
    below = (mid - left + 1) * pivot - (prefix[mid + 1] - prefix[left])
    # Members above the median contribute (value - pivot) each.
    above = (prefix[right + 1] - prefix[mid + 1]) - (right - mid) * pivot
    return below + above


def min_total_cost(values, n=None, *, group_sizes=(2, 3)):
    """Return the minimum total median-deviation cost of grouping ``values``.

    ``values`` are sorted, and the first ``n`` sorted values are split into
    contiguous runs whose lengths are taken from ``group_sizes``. Each run
    costs the sum of absolute deviations from its median.

    Args:
        values: Iterable of integers; it is not modified.
        n: How many of the smallest values to partition. Defaults to all.
        group_sizes: Allowed run lengths (positive integers).

    Returns:
        The minimum total cost as an int (0 when ``n == 0``).

    Raises:
        ValueError: If ``n`` is out of range, a group size is not positive,
            or the ``n`` values cannot be tiled by the allowed sizes
            (e.g. ``n == 1`` with the default sizes).
    """
    ordered = sorted(values)
    if n is None:
        n = len(ordered)
    if n < 0 or n > len(ordered):
        raise ValueError(f"n={n} is out of range for {len(ordered)} values")
    sizes = tuple(group_sizes)
    if any(size < 1 for size in sizes):
        raise ValueError("group sizes must be positive")

    prefix = [0] * (n + 1)
    for i in range(n):
        prefix[i + 1] = prefix[i] + ordered[i]

    # best[i] is the minimum cost for the first i sorted values, or None when
    # they cannot be tiled. Staying integer-only avoids the float('inf')
    # sentinel, which crashed with OverflowError when converted via int().
    best = [None] * (n + 1)
    best[0] = 0
    for end in range(1, n + 1):
        options = [
            best[end - size] + _median_cost(ordered, prefix, end - size, end - 1)
            for size in sizes
            if size <= end and best[end - size] is not None
        ]
        if options:
            best[end] = min(options)

    if best[n] is None:
        raise ValueError(f"cannot split {n} values into groups of sizes {sizes}")
    return best[n]


def main():
    """Read ``n`` and the values from stdin and print the minimum cost."""
    n = int(sys.stdin.readline())
    values = list(map(int, sys.stdin.readline().split()))
    print(min_total_cost(values, n))


if __name__ == "__main__":
    main()