import java.util.ArrayDeque;
import java.util.Deque;
import java.util.Map;
import java.util.Scanner;
import java.util.TreeMap;

/**
 * Reads {@code n} followed by {@code n} integers and prints the summed index distance of a greedy
 * matching between "empty" cells (value 0) and surplus units (each cell with value {@code x > 1}
 * keeps one unit and offers {@code x - 1} surplus units).
 *
 * <p>Every match costs {@code laterIndex - earlierIndex}. A cell whose value is 1 (or negative) is
 * neither a donor nor a receiver. Zeros or surplus units still unmatched after the last input value
 * are ignored and do not affect the printed total.
 *
 * <p>NOTE(review): this presumably answers "minimum total moves so every cell ends with exactly one
 * item", with the input guaranteeing a balanced total. Confirm against the problem statement.
 */
public class Main {

    /**
     * Entry point: reads the input from {@code System.in} and writes the total to
     * {@code System.out}.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        try (Scanner sc = new Scanner(System.in)) {
            int n = sc.nextInt();
            // Index -> surplus units at that index not yet handed to a later zero.
            TreeMap<Integer, Integer> stock = new TreeMap<>();
            // Indices of zeros not yet filled by a later surplus; most recent on top (LIFO).
            Deque<Integer> zeros = new ArrayDeque<>();
            long total = 0;

            for (int i = 0; i < n; i++) {
                int x = sc.nextInt();
                if (x == 0) {
                    if (stock.isEmpty()) {
                        // Nothing to the left can fill this cell yet; wait for a later surplus.
                        zeros.push(i);
                    } else {
                        // Take one unit from the most recent index that still has surplus.
                        Map.Entry<Integer, Integer> last = stock.lastEntry();
                        int key = last.getKey();
                        int remaining = last.getValue();
                        total += i - key;
                        if (remaining == 1) {
                            stock.remove(key);
                        } else {
                            stock.put(key, remaining - 1);
                        }
                    }
                } else if (x > 1) {
                    // This cell keeps one unit; the rest first fills pending zeros (most recent first).
                    int surplus = x - 1;
                    while (surplus > 0 && !zeros.isEmpty()) {
                        total += i - zeros.pop();
                        surplus--;
                    }
                    if (surplus > 0) {
                        // No pending zeros left: keep the remainder for zeros that appear later.
                        stock.put(i, surplus);
                    }
                }
            }
            System.out.println(total);
        }
    }
}