def solve(k, scores):
    """Return the script's answer for cutoff rank ``k`` over ``scores``.

    The k-th largest value in ``scores`` is used as the cutoff:

    * if that cutoff value occurs more than once in ``scores``, return how
      many elements are strictly greater than it;
    * otherwise return the cutoff value itself.

    ``scores`` is not modified.

    Assumes ``1 <= k <= len(scores)`` (TODO confirm against the problem
    constraints): ``k == 0`` would silently pick the smallest element via
    ``ranked[-0]``, and ``k > len(scores)`` raises ``IndexError``.
    """
    # The original code wrote ``p.sort`` without parentheses, which never
    # sorted anything; the cutoff must be taken from sorted data.
    ranked = sorted(scores)
    cutoff = ranked[-k]  # k-th largest value
    if ranked.count(cutoff) > 1:
        return sum(1 for s in ranked if s > cutoff)
    # NOTE(review): this branch returns a *value* while the branch above
    # returns a *count*. If the intent is "how many are selected", this
    # should probably be ``k``; confirm against the problem statement.
    return cutoff


def main():
    """Read ``n k`` and the list of integers from stdin; print the answer."""
    _, k = map(int, input().split())  # n is part of the input format only
    p = list(map(int, input().split()))
    print(solve(k, p))


if __name__ == "__main__":
    main()