#include <iostream>
#include <algorithm>
#include <vector>

using lint = long long;

/// Reads a count `n` followed by `n` integers from stdin and prints one
/// integer followed by a newline to stdout.
///
/// The values are sorted ascending into `vals`. Then, scanning right to
/// left, `chain[k]` starts at 1 and becomes `chain[t] + 1` for each
/// position t in {k + 1, k + 2} (within bounds) with vals[t] == vals[k] + 2.
/// When both offsets match, the larger offset is applied last and wins.
/// The printed value is `n` plus `chain[k]` for every k >= 1 where
/// vals[k] == vals[k - 1] + 1.
void solve() {
    int n;
    std::cin >> n;

    std::vector<lint> vals(n);
    for (int k = 0; k < n; ++k) {
        std::cin >> vals[k];
    }
    std::sort(vals.begin(), vals.end());

    // Right-to-left so that chain[k + 1] and chain[k + 2] are final
    // before chain[k] reads them.
    std::vector<lint> chain(n, 1);
    for (int k = n - 1; k >= 0; --k) {
        const int last = std::min(k + 2, n - 1);
        for (int t = k + 1; t <= last; ++t) {
            if (vals[k] + 2 == vals[t]) {
                chain[k] = chain[t] + 1;
            }
        }
    }

    lint total = n;
    for (int k = 1; k < n; ++k) {
        if (vals[k] == vals[k - 1] + 1) {
            total += chain[k];
        }
    }
    std::cout << total << "\n";
}

/// Program entry point: disables C-stdio synchronization and unties
/// std::cin from std::cout before any I/O happens, then runs solve().
int main() {
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr);

    solve();
}