from bisect import bisect N, M = map(int, input().split()) A = list(map(int, input().split())) B = list(map(int, input().split())) C = [[] for _ in range(N+1)] for i in range(N): C[i+1].append(A[i]) for b in B: ind = bisect(A, b) C[ind].append(b) for i in range(N): C[i].append(A[i]) ans = 0 for i in range(1, N): if len(C[i]) == 2: continue L = C[i][-1]-C[i][0] v = L for j in range(len(C[i])-1): v = min(v, L-(C[i][j+1]-C[i][j])) ans += v if len(C[0]) > 1: ans += C[0][-1]-C[0][0] if len(C[N]) > 1: ans += C[N][-1]-C[N][0] print(ans)