def _window_cost(values, n, m):
    """Return the smallest sum of squared adjacent gaps over m-long windows.

    ``values`` is expected to be sorted ascending.  A window starting at
    index ``start`` covers ``values[start:start + m]``.  Its cost is the sum
    of ``(values[j + 1] - values[j]) ** 2`` over consecutive pairs inside it.
    Window starts run from 0 through ``n - m``.
    """

    def gap_sq(j):
        # Squared distance between the j-th and (j+1)-th sorted values.
        return (values[j + 1] - values[j]) ** 2

    cost = sum(gap_sq(j) for j in range(m - 1))
    best = cost
    for start in range(1, n - m + 1):
        # Shift the window right by one: the gap leaving on the left is
        # (start-1, start); the gap entering on the right is
        # (start+m-2, start+m-1).
        cost = cost - gap_sq(start - 1) + gap_sq(start + m - 2)
        if cost < best:
            best = cost
    return best


def main():
    """Read ``n m`` and then ``n`` integers from stdin; print the minimal window cost."""
    n, m = map(int, input().split())
    values = sorted(map(int, input().split()))
    print(_window_cost(values, n, m))


if __name__ == "__main__":
    main()