def _reach_sum(cells):
    """Sum *cells* in order, stopping after (and including) the first value < 2.

    Cells with value >= 2 let the scan continue past them; the first cell
    with value 0 or 1 is still counted but ends the scan.  An empty iterable
    yields 0, which is what makes the list edges safe.
    """
    total = 0
    for value in cells:
        total += value
        if value < 2:
            break
    return total


def solve(n, k, alst):
    """Return the answer for a start cell ``k`` (1-indexed) on the line ``alst``.

    NOTE(review): presumably ``alst[i]`` is how many times cell ``i`` may be
    used -- confirm against the problem statement.  What the code computes:

    * if the start cell's value is 0, the answer is 0;
    * otherwise each side's contribution is the sum of the cells walking
      outward from the start while values are >= 2, plus the first cell
      whose value is < 2 (see ``_reach_sum``);
    * if the start cell's value is 1, only the better side is used (+1);
    * otherwise both sides are added to the start cell's value.

    Args:
        n: number of cells; the right scan is limited to ``alst[:n]``.
        k: 1-indexed start position.
        alst: cell values.

    Returns:
        The computed total as an int.
    """
    start = k - 1  # input position is 1-indexed
    if alst[start] == 0:
        return 0

    # Slicing/reversing (instead of stepping an index) keeps both scans inside
    # the list: at the edges the side is simply empty, so there is no
    # negative-index wrap-around on the left and no IndexError on the right.
    left_sum = _reach_sum(reversed(alst[:start]))
    right_sum = _reach_sum(alst[start + 1:n])

    if alst[start] == 1:
        # The start cell can only be used once, so only one side is taken.
        return max(left_sum, right_sum) + 1
    return left_sum + right_sum + alst[start]


def main():
    """Read ``n k`` and the cell values from stdin and print the answer."""
    n, k = map(int, input().split())
    alst = list(map(int, input().split()))
    print(solve(n, k, alst))


if __name__ == "__main__":
    main()