"""Sum the distance travelled while visiting queried values in an array.

Input is read from standard input as whitespace-separated tokens:

    N Q
    A_1 ... A_N
    B_1 ... B_Q

Positions in ``A`` are 1-based and the walk starts at position 1.  For each
query value, the walk moves to the closest position holding that value and
the distance moved is added to the total, which is printed at the end.
"""
import bisect
import sys
from collections import defaultdict
from typing import Iterable, Sequence


def _nearest_index(indices: Sequence[int], position: int) -> int:
    """Return the element of sorted, non-empty ``indices`` closest to ``position``.

    When two candidates are equally far away, the smaller index wins.
    """
    pos = bisect.bisect_left(indices, position)
    if pos == 0:
        # Every candidate is at or to the right of ``position``.
        return indices[0]
    if pos == len(indices):
        # Every candidate is strictly to the left of ``position``.
        return indices[-1]
    before = indices[pos - 1]
    after = indices[pos]
    # ``<=`` makes ties resolve to the left-hand (smaller) index.
    if position - before <= after - position:
        return before
    return after


def _total_travel_cost(values: Sequence[int], targets: Iterable[int]) -> int:
    """Return the total distance walked when visiting ``targets`` in order.

    The walk starts at position 1 (1-based) and, for every target, jumps to
    the nearest position in ``values`` that holds the target value.

    NOTE(review): choosing the nearest occurrence at every step is a greedy
    strategy; if the intent is to minimise the *total* distance, greedy is
    not optimal in general -- confirm against the problem statement.

    Raises:
        IndexError: if a target value does not occur in ``values``.
    """
    positions_by_value = defaultdict(list)
    for index, value in enumerate(values, start=1):
        # Indices are appended in increasing order, so each list stays sorted
        # and can be searched with ``bisect``.
        positions_by_value[value].append(index)

    position = 1
    total = 0
    for target in targets:
        # ``.get`` avoids silently inserting an empty list for unknown values.
        candidates = positions_by_value.get(target)
        if not candidates:
            # IndexError keeps the failure type of an empty-list lookup while
            # telling the user which query value is the problem.
            raise IndexError(f"query value {target} does not occur in A")
        nearest = _nearest_index(candidates, position)
        total += abs(position - nearest)
        position = nearest
    return total


def main():
    """Read ``N Q``, ``A`` and ``B`` from stdin and print the total cost."""
    # Read stdin at call time so a redirected ``sys.stdin`` is honoured and
    # the builtin ``input`` is not shadowed.
    data = sys.stdin.read().split()
    n = int(data[0])
    q = int(data[1])
    values = list(map(int, data[2:n + 2]))
    # Only the first Q tokens after A are queries; ignore anything after them.
    targets = list(map(int, data[n + 2:n + 2 + q]))
    print(_total_travel_cost(values, targets))


if __name__ == "__main__":
    main()