import java.util.*;

/**
 * Reads a sequence of n integers and a swap budget k, and prints "YES" if the sequence can be
 * sorted into 1..n using exactly k swaps, otherwise "NO".
 *
 * <p>Input (stdin): {@code n k} followed by {@code n} integers. The integers are expected to be a
 * permutation of {@code 1..n}.
 */
public class Main {

    public static void main(String[] args) {
        try (Scanner sc = new Scanner(System.in)) {
            int n = sc.nextInt();
            long k = sc.nextLong();
            int[] values = new int[n];
            for (int i = 0; i < n; i++) {
                values[i] = sc.nextInt();
            }
            System.out.println(canSortWithExactSwaps(values, k) ? "YES" : "NO");
        }
    }

    /**
     * Returns whether {@code perm} can be sorted using exactly {@code k} swaps.
     *
     * <p>Every swap flips the permutation's parity. So once the minimum number of swaps m is
     * reached, any extra swaps must come in cancelling pairs. The answer is therefore
     * {@code k >= m && (k - m)} even.
     *
     * @param perm a permutation of {@code 1..perm.length}; not modified
     * @param k the exact number of swaps to use
     * @return true if exactly {@code k} swaps can sort {@code perm}
     * @throws IllegalArgumentException if {@code perm} is not a permutation of {@code 1..n}
     */
    static boolean canSortWithExactSwaps(int[] perm, long k) {
        long min = minSwapsToSort(perm);
        return k >= min && (k - min) % 2 == 0;
    }

    /**
     * Returns the minimum number of swaps needed to sort {@code perm} into {@code 1..n}.
     *
     * <p>It greedily places value {@code i + 1} at index {@code i}, swapping it in from wherever it
     * currently sits. A position table keeps each lookup O(1), for O(n) total.
     *
     * @param perm a permutation of {@code 1..perm.length}; not modified
     * @return the minimum swap count (0 for an empty or already-sorted array)
     * @throws IllegalArgumentException if {@code perm} is not a permutation of {@code 1..n}
     */
    static int minSwapsToSort(int[] perm) {
        int n = perm.length;
        int[] a = perm.clone();
        // pos[v] = current index of value v in a; -1 marks "not seen yet" for duplicate detection.
        int[] pos = new int[n + 1];
        Arrays.fill(pos, -1);
        for (int i = 0; i < n; i++) {
            int v = a[i];
            if (v < 1 || v > n || pos[v] != -1) {
                throw new IllegalArgumentException(
                        "input is not a permutation of 1.." + n + ": " + Arrays.toString(perm));
            }
            pos[v] = i;
        }

        int swaps = 0;
        for (int i = 0; i < n; i++) {
            int want = i + 1;
            int j = pos[want];
            if (j == i) {
                continue;
            }
            // Swap a[i] and a[j]. The displaced value moves to j; BOTH positions must be updated
            // (the original code stored the pair inverted, corrupting later lookups).
            int displaced = a[i];
            a[j] = displaced;
            a[i] = want;
            pos[displaced] = j;
            pos[want] = i;
            swaps++;
        }
        return swaps;
    }
}