import sys


def _gap_cost(left, pts, right):
    """Return the cheapest way to attach the b-points ``pts`` to a neighbouring anchor.

    ``pts`` is a non-empty, ascending list of b-points lying between the anchor
    ``left`` and the anchor ``right``.  ``left`` is ``None`` for the points that
    precede the first anchor, in which case they can only go to ``right``.

    A prefix ``pts[:k+1]`` sent to ``left`` costs ``pts[k] - left`` and a suffix
    ``pts[k+1:]`` sent to ``right`` costs ``right - pts[k+1]``, i.e. each group
    pays the distance from its anchor to its farthest member.
    """
    # Everything attached to the right anchor.
    best = right - pts[0]
    if left is None:
        return best
    # Everything attached to the left anchor.
    best = min(best, pts[-1] - left)
    # Split between two consecutive points: prefix goes left, suffix goes right.
    for p, q in zip(pts, pts[1:]):
        best = min(best, (p - left) + (right - q))
    return best


def min_total_cost(a, b):
    """Return the minimum total cost of attaching every b-point to an a-point.

    The a-points act as anchors.  The b-points lying strictly between two
    consecutive anchors are split into a left prefix and a right suffix (see
    ``_gap_cost``).  b-points before the first anchor can only use that anchor,
    and b-points after the last anchor can only use the last one.  A b-point at
    the same coordinate as an a-point sorts after it and so costs 0.

    Args:
        a: anchor coordinates, in any order.
        b: coordinates of points to attach, in any order.

    Returns:
        The minimum total cost (0 when ``b`` is empty).

    Raises:
        ValueError: if ``b`` is non-empty but ``a`` is empty (no anchor exists).

    Runs in O((n + m) log(n + m)), dominated by the sort.
    """
    # Tag 0 = anchor (a), 1 = point (b).  On equal coordinates the anchor sorts
    # first, so a coincident b-point is costed against it at distance 0.
    events = sorted([(x, 0) for x in a] + [(x, 1) for x in b])

    total = 0
    left = None   # most recent anchor; None until the first anchor is seen
    pending = []  # b-points seen since `left`, ascending because events is sorted
    for pos, kind in events:
        if kind == 1:
            pending.append(pos)
            continue
        if pending:
            total += _gap_cost(left, pending, pos)
            pending = []
        left = pos

    if pending:
        # Trailing points have no right anchor and must all go to the last one.
        if left is None:
            raise ValueError("b-points present but there is no a-point to attach them to")
        total += pending[-1] - left
    return total


def main():
    """Read ``n m``, then ``n`` a-coordinates and ``m`` b-coordinates, and print the answer."""
    data = sys.stdin.read().split()
    n, m = int(data[0]), int(data[1])
    a = [int(t) for t in data[2:2 + n]]
    b = [int(t) for t in data[2 + n:2 + n + m]]
    print(min_total_cost(a, b))


if __name__ == "__main__":
    main()