def _count_subarrays_with_zero(values):
    """Return how many non-empty contiguous subarrays of *values* contain a 0.

    The count is obtained by complement: the total number of non-empty
    contiguous subarrays, minus the number of those that contain no zero.

    Args:
        values: Sequence of integers to inspect.

    Returns:
        The number of contiguous subarrays holding at least one zero
        (0 for an empty sequence).
    """
    size = len(values)
    # Closed form for 1 + 2 + ... + size: one subarray per (start, end) pair.
    total = size * (size + 1) // 2

    zero_free = 0
    last_zero = -1  # Index of the most recent zero; -1 means "none seen yet".
    for idx, value in enumerate(values):
        if value == 0:
            last_zero = idx
        # Zero-free subarrays ending at idx may start anywhere strictly after
        # the last zero, which gives exactly idx - last_zero choices.
        zero_free += idx - last_zero

    return total - zero_free


def main():
    """Read a length and an integer list from stdin and print the zero count.

    The first input line holds ``n``; the second holds whitespace-separated
    integers. Only the first ``n`` integers are considered. The number of
    contiguous subarrays containing at least one zero is printed.
    """
    n = int(input())
    # Slice to n so surplus tokens on the line are ignored.
    values = list(map(int, input().split()))[:n]
    print(_count_subarrays_with_zero(values))


if __name__ == "__main__":
    main()