n,k = map(int,input().split()) p = list(map(int,input().split())) p.sort() if p.count(p[-k]) >1: print(sum([i > p[-k] for i in p])) else: print(k)