"""Bucket B values into the gaps between sorted A points and sum a per-gap cost.

Input (stdin): a line ``N M``, a line of A values, a line of B values.
Output (stdout): the integer computed by :func:`solve`.
"""
from bisect import bisect_left
from collections import defaultdict, deque


def diff_list(List):
    """Return the differences between consecutive items of *List*.

    ``diff_list([1, 4, 6]) == [3, 2]``; fewer than two items gives ``[]``.
    """
    # The parameter keeps its original name so keyword callers still work.
    return [nxt - cur for cur, nxt in zip(List, List[1:])]


def solve(a_values, b_values):
    """Return the summed per-gap cost of reaching every B value from the A points.

    A sentinel ``0`` is merged into the A values and the result is sorted into
    ``points``. Each ``b`` goes into bucket ``i``, the last index whose point is
    strictly less than ``b``. Values with no smaller point go to bucket 0. Then:

    * last bucket: cost is ``max - min`` of the bucket together with
      ``points[-1]`` (values to the right of the final point);
    * bucket 0 (the gap after the sentinel): cost is ``max - min`` of the
      bucket together with ``points[1]``;
    * any other bucket ``i``: the gap ``[points[i], points[i + 1]]`` is
      entered from both ends, so the cost is its length minus the widest
      sub-gap between consecutive members.

    NOTE(review): presumably this is a minimal total walking distance for the
    original task; confirm against the problem statement.
    """
    points = sorted(list(a_values) + [0])
    last = len(points) - 1

    buckets = [[] for _ in points]
    for b in b_values:
        # bisect_left(...) - 1 is the last index holding a value < b; equal
        # values are contiguous, so duplicates resolve to their last copy.
        buckets[max(bisect_left(points, b) - 1, 0)].append(b)

    total = 0
    for i, bucket in enumerate(buckets):
        if i == last:  # checked first: with a single point, bucket 0 is also last
            members = bucket + [points[i]]
            total += max(members) - min(members)
        elif i == 0:
            members = bucket + [points[1]]
            total += max(members) - min(members)
        else:
            members = sorted(bucket + [points[i], points[i + 1]])
            # With no B values inside, the widest gap is the whole span -> 0.
            total += (members[-1] - members[0]) - max(diff_list(members))
    return total


def main():
    """Read the three input lines from stdin and print :func:`solve`'s result."""
    _n, _m = map(int, input().split())  # counts are implied by the next lines
    A = list(map(int, input().split()))
    B = list(map(int, input().split()))
    print(solve(A, B))


if __name__ == "__main__":
    main()