def _count_inversions(values):
    """Return the number of pairs (i, j), i < j, with values[i] > values[j].

    Uses a Fenwick (binary indexed) tree over the ranks of the distinct
    values, so equal values are never counted as an inversion.
    """
    ranks = {v: r for r, v in enumerate(sorted(set(values)), start=1)}
    size = len(ranks)
    tree = [0] * (size + 1)
    inversions = 0
    for seen, v in enumerate(values):
        r = ranks[v]
        # Number of earlier elements that are <= v.
        not_greater = 0
        i = r
        while i > 0:
            not_greater += tree[i]
            i -= i & -i
        # Every other earlier element is strictly greater than v.
        inversions += seen - not_greater
        i = r
        while i <= size:
            tree[i] += 1
            i += i & -i
    return inversions


def _min_group_swaps(a, k):
    """Return the summed per-group inversion count, or -1 if a cannot be sorted.

    Positions are grouped by index modulo k, i.e. a position is linked to the
    one k places after it. Presumably the allowed operation is swapping the
    elements at positions i and i + k (TODO confirm against the problem
    statement). Under that operation, the array can be sorted iff each group
    holds exactly the values the sorted array has at those positions.

    For k < 1 no two positions are linked, so every position is its own group.
    """
    n = len(a)
    if n == 0:
        return 0
    step = k if k >= 1 else n
    target = sorted(a)
    total = 0
    for r in range(min(step, n)):
        group = a[r::step]  # values of this class, in positional order
        # target[r::step] is already sorted because target is sorted.
        # Compare multisets: this is correct even with duplicate values.
        if sorted(group) != target[r::step]:
            return -1
        total += _count_inversions(group)
    return total


def main():
    """Read input from stdin and print the answer.

    Input format: "N K" followed by N integers.
    Output: the result of _min_group_swaps, i.e. -1 if the array cannot be
    sorted with the K-spaced grouping, otherwise the summed per-group
    inversion count.
    """
    import sys

    data = sys.stdin.read().split()
    n, k = int(data[0]), int(data[1])
    a = [int(tok) for tok in data[2:2 + n]]
    print(_min_group_swaps(a, k))


if __name__ == "__main__":
    main()