import java.util.Scanner;

/**
 * Reads {@code n} followed by {@code n} integers from standard input and prints how many
 * contiguous, non-empty subarrays contain at least one element that is not equal to 1.
 */
public class Main {
    public static void main(String[] args) {
        int[] a;
        // try-with-resources guarantees the scanner is closed even if parsing throws.
        try (Scanner sc = new Scanner(System.in)) {
            int n = sc.nextInt();
            a = new int[n];
            for (int i = 0; i < n; i++) {
                a[i] = sc.nextInt();
            }
        }
        System.out.println(countSubarraysWithNonOne(a));
    }

    /**
     * Counts the contiguous, non-empty subarrays of {@code a} that contain at least one
     * element different from 1.
     *
     * <p>Uses complement counting: start from the total number of subarrays, n(n+1)/2, and
     * subtract the subarrays consisting solely of 1s. A maximal run of k consecutive 1s
     * contributes exactly k(k+1)/2 such all-ones subarrays.
     *
     * @param a the input values; an empty array yields 0
     * @return the number of subarrays containing a value other than 1
     */
    static long countSubarraysWithNonOne(int[] a) {
        long result = triangular(a.length);
        int onesRun = 0;
        for (int value : a) {
            if (value == 1) {
                onesRun++;
            } else {
                result -= triangular(onesRun);
                onesRun = 0;
            }
        }
        // The array may end inside a run of 1s that was never closed by a non-1 value.
        return result - triangular(onesRun);
    }

    /** Returns k(k+1)/2, the number of non-empty subarrays of a length-k sequence, in long arithmetic. */
    private static long triangular(long k) {
        return k * (k + 1) / 2;
    }
}