"""Pick m values whose sorted neighbours are as close as possible.

Input (stdin): ``n m`` followed by ``n`` integers, whitespace-separated.
Output (stdout): the minimum, over all choices of ``m`` values, of the sum
of squared differences between neighbouring chosen values in sorted order.
"""
import sys
from itertools import accumulate


def min_window_cost(values, m):
    """Return the minimum squared-gap cost of choosing ``m`` of ``values``.

    The cost of a chosen group is the sum of ``(b - a) ** 2`` over each pair of
    neighbouring members ``a <= b`` in sorted order. Only contiguous windows of
    the sorted values are examined. That suffices because inserting an unchosen
    interior value splits a gap ``a + b`` into ``a`` and ``b``, and
    ``a**2 + b**2 <= (a + b)**2``. Dropping an end member then only removes a
    non-negative term, so some contiguous window is always at least as cheap.

    Args:
        values: the candidate integers (not modified).
        m: how many values to choose; must satisfy ``1 <= m <= len(values)``.

    Returns:
        The minimum cost (0 when ``m == 1``).

    Raises:
        ValueError: if ``m`` is outside ``1..len(values)``.
    """
    if not 1 <= m <= len(values):
        raise ValueError(f"m must be between 1 and {len(values)}, got {m}")
    ordered = sorted(values)
    # Squared differences between each pair of neighbours in sorted order.
    gaps = [(hi - lo) ** 2 for lo, hi in zip(ordered, ordered[1:])]
    # prefix[k] is the sum of the first k gaps, so any window sum is O(1).
    prefix = [0, *accumulate(gaps)]
    span = m - 1  # a window of m values contains m - 1 gaps
    return min(prefix[i + span] - prefix[i] for i in range(len(ordered) - span))


def main():
    """Read the problem from stdin and print the answer."""
    # Read every token at once so the array may span any number of lines.
    data = sys.stdin.read().split()
    n, m = int(data[0]), int(data[1])
    values = [int(tok) for tok in data[2:2 + n]]
    print(min_window_cost(values, m))


if __name__ == "__main__":
    main()