"""Read ``n k`` and a line of ``n`` 1-based integers; print ``YES`` or ``NO``.

The answer is ``YES`` exactly when ``k`` is at least the step count computed
by the scan in :func:`solve` and differs from it by an even amount.
"""

from typing import Sequence


def solve(n: int, k: int, d: Sequence[int]) -> str:
    """Return ``"YES"`` or ``"NO"`` for the given ``n``, ``k`` and values ``d``.

    Args:
        n: Number of leading positions of ``d`` to scan.
        k: Target count compared against the computed step count.
        d: 0-based values (the raw input values minus one). Not modified.

    Returns:
        ``"YES"`` if ``k >= cnt`` and ``k - cnt`` is even, else ``"NO"``,
        where ``cnt`` is the number of positions counted by the scan below.

    Raises:
        IndexError: If ``d`` has fewer than ``n`` entries or contains a value
            that is not a valid index into it.
    """
    # Work on a copy: the scan overwrites entries as it goes.
    work = list(d)

    # NOTE(review): the previous version also kept a flag that was meant to
    # force "YES" whenever k exceeded the count, but its condition compared
    # d[i] with a variable that had just been assigned d[i], so it could never
    # fire. The dead flag is removed here; confirm against the problem
    # statement whether such a special case was actually intended.
    cnt = 0
    for i in range(n):
        j = work[i]
        if i == j:
            continue
        cnt += 1
        # Mark position j as settled so it is not counted again when reached.
        work[j] = j

    # k must reach cnt, and any surplus must be even.
    return "YES" if k >= cnt and (k - cnt) % 2 == 0 else "NO"


def main() -> None:
    """Read the two input lines from stdin and print the verdict."""
    n, k = map(int, input().split())
    d = [int(tok) - 1 for tok in input().split()]
    print(solve(n, k, d))


if __name__ == "__main__":
    main()