import sys

# Sentinel anchor placed beyond both ends of the line. It keeps the per-gap
# formula uniform at the edges: reaching a point from a sentinel costs about
# 10**18, so that option never wins. Assumes every coordinate is far smaller
# in magnitude than this value.
_INF = 10**18


def solve(a, b):
    """Return the summed per-gap minimum cost for points ``b`` between anchors ``a``.

    The anchors are sorted and padded with sentinels at -``_INF`` and +``_INF``.
    The points of ``b`` are then grouped by the consecutive anchor pair
    ``(left, right)`` they fall between. A point equal to an anchor goes to
    the earlier gap.

    For a non-empty sorted group ``g`` the gap cost is the minimum of:
      * ``g[-1] - left``: everything reached from the left anchor,
      * ``right - g[0]``: everything reached from the right anchor,
      * ``(g[j] - left) + (right - g[j + 1])`` for each adjacent pair: split
        between ``g[j]`` and ``g[j + 1]``.
    The answer is the sum of these gap costs.

    Args:
        a: Anchor coordinates, in any order.
        b: Point coordinates, in any order.

    Returns:
        The total cost as an int. Returns 0 when ``b`` is empty.
    """
    anchors = [-_INF, *sorted(a), _INF]
    points = sorted(b)  # the sweep below relies on ascending order
    total = 0
    k = 0
    m = len(points)
    for left, right in zip(anchors, anchors[1:]):
        start = k
        # Points <= left were consumed by earlier gaps, so only the upper
        # bound needs checking here.
        while k < m and points[k] <= right:
            k += 1
        group = points[start:k]
        if not group:
            continue
        best = min(group[-1] - left, right - group[0])
        for x, y in zip(group, group[1:]):
            best = min(best, (x - left) + (right - y))
        total += best
    return total


def main():
    """Read ``N M``, then N anchors and M points from stdin; print the answer."""
    data = sys.stdin.buffer.read().split()
    n, m = int(data[0]), int(data[1])
    a = list(map(int, data[2:2 + n]))
    b = list(map(int, data[2 + n:2 + n + m]))
    print(solve(a, b))


if __name__ == "__main__":
    main()