def _count_windows_with_all(values, lo, hi, need):
    """Count subarrays of ``values[lo..hi]`` containing every integer 1..need.

    Both ``lo`` and ``hi`` are inclusive indices with ``lo <= hi``.  When
    ``need`` is 0 nothing is required, so every subarray of the span counts.
    """
    length = hi - lo + 1
    if need == 0:
        return length * (length + 1) // 2
    if length < need:
        # Too short to hold ``need`` distinct values.
        return 0

    # Multiplicity of each required value inside the current window.  A plain
    # set is not enough: dropping one copy of a value from the left edge must
    # not forget other copies that are still inside the window.
    multiplicity = {}
    distinct = 0
    left = lo
    found = 0
    for right in range(lo, hi + 1):
        incoming = values[right]
        if 1 <= incoming <= need:
            multiplicity[incoming] = multiplicity.get(incoming, 0) + 1
            if multiplicity[incoming] == 1:
                distinct += 1
        # While the window [left, right] is complete, every extension of its
        # right end up to ``hi`` is complete too: that is hi - right + 1
        # subarrays starting at ``left``.  Then advance ``left``.
        # (``need >= 1`` guarantees a complete window is non-empty, so
        # ``left`` never overtakes ``right`` here.)
        while distinct == need:
            found += hi - right + 1
            outgoing = values[left]
            if 1 <= outgoing <= need:
                multiplicity[outgoing] -= 1
                if multiplicity[outgoing] == 0:
                    distinct -= 1
            left += 1
    return found


def _sum_of_mex(values):
    """Return the sum, over all contiguous subarrays, of the subarray's MEX.

    MEX here is the smallest *positive* integer absent from the subarray.
    A subarray has MEX ``x`` exactly when it avoids ``x`` and contains all of
    1..x-1, so for each ``x`` the array is cut at the occurrences of ``x`` and
    the qualifying subarrays are counted inside each remaining span.
    """
    n = len(values)
    positions = {}
    for index, value in enumerate(values):
        positions.setdefault(value, []).append(index)

    # No subarray can have a MEX above the MEX of the whole array, so larger
    # candidates contribute nothing (and need not be iterated, however large
    # the input values are).
    global_mex = 1
    while global_mex in positions:
        global_mex += 1

    total = 0
    for x in range(1, global_mex + 1):
        matching = 0
        start = 0
        # ``n`` acts as a sentinel cut so the trailing span is processed too.
        for cut in positions.get(x, []) + [n]:
            if start < cut:
                matching += _count_windows_with_all(values, start, cut - 1, x - 1)
            start = cut + 1
        total += x * matching
    return total


def main():
    """Read ``N`` followed by ``N`` integers from stdin and print the MEX sum.

    The printed value is the sum over all contiguous subarrays of the smallest
    positive integer missing from that subarray.  Empty input or ``N == 0``
    prints 0, since there are no subarrays to sum over.
    """
    import sys

    tokens = sys.stdin.read().split()
    if not tokens:
        print(0)
        return
    n = int(tokens[0])
    values = [int(token) for token in tokens[1:n + 1]]
    print(_sum_of_mex(values))


if __name__ == "__main__":
    main()