import collections,sys,math,functools,operator,itertools,bisect,heapq,decimal,string,time,random


def min_window_cost(values, m):
    """Return the minimum squared-gap cost over all windows of m sorted values.

    The values are sorted, and for every run of ``m`` consecutive sorted
    values the cost is the sum of ``(next - prev) ** 2`` over adjacent
    pairs in that run. The smallest such cost is returned.

    Args:
        values: iterable of integers (order does not matter; it is sorted).
        m: window size, must satisfy ``1 <= m <= len(values)``.

    Returns:
        The minimum window cost (``0`` when ``m == 1``).

    Raises:
        ValueError: if ``m`` is outside ``1..len(values)``.
    """
    xs = sorted(values)
    if not 1 <= m <= len(xs):
        raise ValueError(f"window size m={m} must be in 1..{len(xs)}")

    # gaps[j] is the cost contributed by the adjacent pair (xs[j], xs[j+1]).
    gaps = [(b - a) ** 2 for a, b in zip(xs, xs[1:])]

    # The first window xs[0..m-1] contains pairs gaps[0..m-2].
    window = sum(gaps[:m - 1])
    best = window
    # Sliding from xs[i-m+1..i] to xs[i-m+2..i+1] drops pair (i-m+1, i-m+2)
    # and gains pair (i, i+1); track the minimum without storing every sum.
    for i in range(m - 1, len(gaps)):
        window += gaps[i] - gaps[i - m + 1]
        if window < best:
            best = window
    return best


def main():
    """Read ``n m`` and then n integers from stdin; print the minimum cost."""
    _, m = map(int, input().split())
    values = list(map(int, input().split()))
    print(min_window_cost(values, m))


if __name__ == "__main__":
    main()