# yuki1505


def count_zero_subarrays(a):
    """Return how many contiguous subarrays of *a* contain at least one zero.

    There are ``len(a) * (len(a) + 1) // 2`` subarrays in total. From that,
    subtract the subarrays made up only of nonzero (truthy) values. Those are
    counted in one pass: while walking the sequence, ``run`` is the length of
    the current all-nonzero streak ending at the current element. Exactly
    ``run`` all-nonzero subarrays end at that element.

    Args:
        a: Sequence of integers.

    Returns:
        The number of subarrays containing at least one zero (0 for empty *a*).
    """
    n = len(a)
    total = n * (n + 1) // 2
    run = 0
    all_nonzero = 0
    for value in a:
        if value:
            run += 1
            all_nonzero += run
        else:
            # A zero breaks every all-nonzero subarray passing through it.
            run = 0
    return total - all_nonzero


def main():
    """Read ``n`` and the array from stdin and print the zero-containing subarray count."""
    n = int(input())
    a = list(map(int, input().split()))
    # Only the first n values are used, matching the declared length.
    print(count_zero_subarrays(a[:n]))


if __name__ == "__main__":
    main()