#include <algorithm>
#include <iostream>
#include <vector>

using lint = long long;

namespace {

// Sums xs[start], xs[start + step], xs[start + 2*step], ... while the index
// stays inside [0, xs.size()). The scan stops right after adding the first
// element whose value is <= 1 (that element IS included in the sum).
// step is expected to be +1 or -1.
lint run_sum(const std::vector<lint>& xs, int start, int step) {
    const int size = static_cast<int>(xs.size());
    lint sum = 0;
    for (int i = start; i >= 0 && i < size; i += step) {
        sum += xs[i];
        if (xs[i] <= 1) {
            break;
        }
    }
    return sum;
}

}  // namespace

// Reads n, a 1-based position k, and n integers xs from stdin, then prints:
//   0                          if xs[k] == 0
//   max(left, right) + 1       if xs[k] == 1
//   left + right + xs[k]       otherwise
// where left/right are run_sum scans starting next to k and moving
// outward in each direction.
// NOTE(review): the element type was lost when the source was mangled;
// long long is used so reads and sums cannot overflow — confirm against the
// problem constraints.
void solve() {
    int n = 0;
    int k = 0;
    std::cin >> n >> k;
    --k;  // convert the 1-based input position to a 0-based index

    std::vector<lint> xs(n);
    for (auto& x : xs) {
        std::cin >> x;
    }

    const lint left = run_sum(xs, k - 1, -1);
    const lint right = run_sum(xs, k + 1, +1);

    lint ans = 0;
    if (xs[k] == 0) {
        ans = 0;
    } else if (xs[k] == 1) {
        // Only the larger of the two sides contributes, plus the start cell.
        ans = std::max(left, right) + 1;
    } else {
        ans = left + right + xs[k];
    }
    // '\n' instead of std::endl: the stream is flushed at program exit anyway.
    std::cout << ans << '\n';
}

int main() {
    // Decouple C++ streams from C stdio and untie cin from cout for fast I/O.
    // (cout.tie(nullptr) was dropped: cout is not tied to anything by default.)
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);
    solve();
    return 0;
}