def solve(n, k, perm):
    """Return "YES" or "NO" for budget ``k`` and 0-indexed targets ``perm``.

    Scans positions ``0 .. n-1`` and tallies ``cnt``:

    * a position ``i`` with ``perm[perm[i]] == i`` (a 2-cycle) contributes 1,
      and both ends are marked resolved so the partner is not counted again;
    * every other misplaced position contributes 1 and sets ``has_long``.
      If ``perm`` is a permutation, such a position lies on a cycle of
      length >= 3.

    The answer is "YES" when ``has_long`` and ``k >= cnt``, or when
    ``k == cnt``, or when ``k`` exceeds ``cnt`` by an even amount.
    Otherwise it is "NO".

    Args:
        n: number of leading positions of ``perm`` to scan.
        k: the budget read from input.
        perm: 0-indexed target indices; not modified (a copy is scanned).

    Returns:
        The string "YES" or "NO".

    Raises:
        IndexError: if ``perm`` is shorter than ``n`` or holds an index
            outside its bounds.
    """
    d = list(perm)  # scanned copy: resolved 2-cycles are rewritten in place
    cnt = 0
    has_long = False
    for i in range(n):
        j = d[i]
        if j == i:
            continue
        cnt += 1
        if d[j] == i:
            # 2-cycle: fix both ends so position j (> i) is skipped later.
            d[i] = i
            d[j] = j
        else:
            has_long = True
    if has_long and k >= cnt:
        return "YES"
    if cnt == k or (k > cnt and (k - cnt) % 2 == 0):
        return "YES"
    return "NO"


def main():
    """Read ``n k`` and the 1-indexed array from stdin; print the verdict."""
    n, k = map(int, input().split())
    perm = [int(tok) - 1 for tok in input().split()]  # convert to 0-indexed
    print(solve(n, k, perm))


if __name__ == "__main__":
    main()