#include <algorithm>
#include <iostream>
#include <vector>

using lint = long long;

namespace {

// Computes the value printed by solve() for the given input values.
//
// After sorting, lens[i] is the length of the chain
//   xs[i], xs[i] + 2, xs[i] + 4, ...
// followed through the array. The result is
//   n  +  sum of lens[i + 1] over every adjacent sorted pair
//         with xs[i + 1] == xs[i] + 1.
//
// xs is taken by value because it is sorted in place.
// The result is accumulated in lint because the sum of chain lengths
// can grow quadratically in n.
lint count_answer(std::vector<lint> xs) {
    std::sort(xs.begin(), xs.end());
    const int n = static_cast<int>(xs.size());

    // lens[i] starts at 1: every element is a chain of length one by itself.
    std::vector<int> lens(n, 1);
    for (int i = n - 1; i >= 0; --i) {
        // Only look up to two positions ahead. For sorted DISTINCT integers,
        // the only value that can sit between xs[i] and xs[i] + 2 is
        // xs[i] + 1, so xs[i] + 2 (if present) is at i + 1 or i + 2.
        // NOTE(review): this lookahead assumes the input values are
        // distinct; with duplicates some matches would be missed.
        // Confirm against the problem constraints.
        for (int j = 1; j <= 2 && i + j < n; ++j) {
            if (xs[i + j] == xs[i] + 2) lens[i] = lens[i + j] + 1;
        }
    }

    lint ans = n;
    for (int i = 0; i + 1 < n; ++i) {
        if (xs[i + 1] == xs[i] + 1) ans += lens[i + 1];
    }
    return ans;
}

}  // namespace

// Reads n followed by n integers from stdin and prints count_answer() of them.
void solve() {
    int n = 0;
    std::cin >> n;
    std::vector<lint> xs(n);
    for (auto& x : xs) std::cin >> x;
    std::cout << count_answer(std::move(xs)) << "\n";
}

int main() {
    // Fast I/O: untie cin from cout and decouple from C stdio (none is used).
    std::cin.tie(nullptr);
    std::ios::sync_with_stdio(false);
    solve();
    return 0;
}