import java.util.*;

/**
 * Console program: reads two integers {@code n} and {@code k} followed by {@code n} integers
 * from standard input, then prints a single number derived from the run of values surrounding
 * position {@code k} (1-based).
 */
public class Main {

    /**
     * Reads the input, evaluates the stretch around position {@code k} and prints the result.
     *
     * <ul>
     *   <li>If the value at {@code k} is 0, prints {@code 0}.</li>
     *   <li>If the value at {@code k} is 1, prints the larger of the two side totals plus 1.</li>
     *   <li>Otherwise prints the left total plus the right total plus the value at {@code k}.</li>
     * </ul>
     *
     * @param args unused
     */
    public static void main(String[] args) {
        Scanner in = new Scanner(System.in);
        int n = in.nextInt();
        int k = in.nextInt();

        // Values live at indexes 1..n; indexes 0 and n+1 stay 0 and act as stoppers for the scans.
        int[] cells = new int[n + 2];
        for (int pos = 1; pos <= n; pos++) {
            cells[pos] = in.nextInt();
        }

        int center = cells[k];
        if (center == 0) {
            System.out.println(0);
            return;
        }

        long leftTotal = sweep(cells, k - 1, -1);
        long rightTotal = sweep(cells, k + 1, 1);

        long answer = (center == 1)
                ? Math.max(leftTotal, rightTotal) + 1
                : leftTotal + rightTotal + center;
        System.out.println(answer);
    }

    /**
     * Walks from {@code start} in direction {@code step}, summing every value greater than 1.
     * The walk stops at the first value that is not greater than 1; if that value is exactly 1,
     * one more unit is added to the total.
     *
     * @param cells the padded value array
     * @param start index of the first cell to inspect
     * @param step  {@code -1} to walk towards lower indexes, {@code 1} towards higher ones
     * @return the accumulated total for that side
     */
    private static long sweep(int[] cells, int start, int step) {
        long total = 0;
        int pos = start;
        for (; cells[pos] > 1; pos += step) {
            total += cells[pos];
        }
        if (cells[pos] == 1) {
            total++;
        }
        return total;
    }
}