from scipy.sparse.csgraph import minimum_spanning_tree
from scipy.sparse import csr_matrix


def total_connection_cost(a_values, b_values):
    """Return the weight of the minimum spanning tree of the value graph.

    The graph has one node per entry of ``a_values`` (indices ``0..n-1``),
    one node per entry of ``b_values`` (indices ``n..n+m-1``) and one extra
    hub node (index ``n+m``):

    * every ``a`` node is joined to the hub by a zero-weight edge;
    * all ``a``/``b`` nodes are sorted by value and each pair of neighbours
      in that order is joined by an edge weighted by their value difference.

    Args:
        a_values: iterable of integers for the first group.
        b_values: iterable of integers for the second group.

    Returns:
        The total weight of a minimum spanning tree (forest, if the graph is
        disconnected) of that graph, as a Python ``int``.
    """
    a_values = list(a_values)
    b_values = list(b_values)
    n_a = len(a_values)
    hub = n_a + len(b_values)

    # (value, node index) pairs; ties on value are broken by node index.
    ordered = sorted(
        [(value, idx) for idx, value in enumerate(a_values)]
        + [(value, idx) for idx, value in enumerate(b_values, n_a)]
    )

    # Zero-weight spokes from every `a` node to the hub.
    sources = list(range(n_a))
    targets = [hub] * n_a
    weights = [0] * n_a

    # Chain edges between neighbours in sorted order.
    for (low_value, low_idx), (high_value, high_idx) in zip(ordered, ordered[1:]):
        sources.append(low_idx)
        targets.append(high_idx)
        weights.append(high_value - low_value)

    # Zero weights are stored as explicit sparse entries, so they remain
    # usable edges for the spanning-tree routine.
    graph = csr_matrix((weights, (sources, targets)), shape=(hub + 1, hub + 1))
    tree = minimum_spanning_tree(graph)

    # Cast to a fixed 64-bit integer: the platform default `int` dtype is
    # 32-bit on some NumPy/OS combinations and could overflow on large sums.
    return int(tree.astype("int64").sum())


def main():
    """Read the instance from stdin and print the spanning-tree weight."""
    input()  # header line with the two counts; sizes come from the lists
    a_values = [int(token) for token in input().split()]
    b_values = [int(token) for token in input().split()]
    print(total_connection_cost(a_values, b_values))


if __name__ == "__main__":
    main()