def solve(k, scores):
    """Compute the value the script prints for parameter *k* and *scores*.

    The scores are sorted ascending, and the element at index ``-(k - 1)``
    is taken as the threshold. If that threshold value occurs more than once,
    the result is the number of scores strictly greater than it. Otherwise
    the result is ``k`` itself.

    Args:
        k: Integer read from the first input line.
        scores: Sequence of integers. It is not modified.

    Returns:
        The integer count described above.

    Raises:
        IndexError: If ``scores`` is empty or ``k - 1`` exceeds its length.
    """
    # The original wrote `p.sort` without calling it, so no sort ever
    # happened. sorted() also avoids mutating the caller's list.
    ordered = sorted(scores)
    k1 = k - 1
    # NOTE(review): ordered[-k1] is the (k-1)-th largest value. When k == 1
    # this is ordered[-0] == ordered[0], the smallest value.
    # Confirm this indexing is what the problem intends (perhaps ordered[-k]?).
    threshold = ordered[-k1]
    if ordered.count(threshold) > 1:
        return sum(score > threshold for score in ordered)
    return k


def main():
    """Read ``n k`` and the score list from stdin, then print the answer."""
    _n, k = map(int, input().split())  # n is read but not used by the logic
    p = list(map(int, input().split()))
    print(solve(k, p))


if __name__ == "__main__":
    main()