"""Read positions A and B from stdin and print the total gap cost.

Input format (as parsed below): a header line "N M", then the N values
of A on one line, then the M values of B on one line.
"""
import sys
from bisect import bisect_left

input = sys.stdin.readline


def iinput(): return int(input())
def sinput(): return input().rstrip()
def i0input(): return int(input()) - 1
def linput(): return list(input().split())
def liinput(): return list(map(int, input().split()))
def miinput(): return map(int, input().split())
def li0input(): return list(map(lambda x: int(x) - 1, input().split()))
def mi0input(): return map(lambda x: int(x) - 1, input().split())


INF = 10**20
MOD = 1000000007


def solve(A, B):
    """Return the total cost of reaching every value of B from the values of A.

    Values of B below min(A) are reached from min(A), and values above
    max(A) from max(A); each contributes the distance to the farthest such
    value. Between two consecutive A values ``left < right``, the B values
    lying in ``[left, right)`` are split at one gap: the lower part is reached
    from ``left`` and the upper part from ``right``. The gap chosen minimises
    ``(last B taken by left - left) + (right - first B taken by right)``.

    Args:
        A: Integer positions; must be non-empty when B is non-empty.
        B: Integer positions; may be empty, in which case the cost is 0.

    Returns:
        The summed minimal cost as an int.
    """
    # bisect_left and the consecutive-pair walk both require ascending
    # order; sorting is a no-op for already-sorted input.
    A = sorted(A)
    B = sorted(B)
    if not B:
        # Nothing to reach; also avoids IndexError on B[0] below.
        return 0

    total = 0
    if A[0] > B[0]:
        total += A[0] - B[0]
    if A[-1] < B[-1]:
        total += B[-1] - A[-1]

    # For a single A value this loop is empty, so no special case is needed.
    for left, right in zip(A, A[1:]):
        lo = bisect_left(B, left)    # first B >= left (B == left costs 0)
        hi = bisect_left(B, right)   # B == right is excluded: right covers it
        stops = [left, *B[lo:hi], right]
        # Each adjacent pair (b1, b2) is a candidate split: left extends to
        # b1, right extends down to b2. The (left, right) pair alone gives 0
        # when no B lies strictly between them.
        total += min((b1 - left) + (right - b2)
                     for b1, b2 in zip(stops, stops[1:]))
    return total


def main():
    """Parse the input format described in the module docstring and print the answer."""
    _n, _m = liinput()  # sizes are implied by the list lengths
    A = liinput()
    B = liinput()
    print(solve(A, B))


if __name__ == "__main__":
    main()