"""Count the contiguous subarrays that contain at least one zero.

Input (stdin): a line with n, then a line of n integers.
Output (stdout): the number of non-empty contiguous subarrays that include
at least one element equal to zero.
"""


def count_subarrays_with_zero(values):
    """Return how many non-empty contiguous subarrays of *values* contain a zero.

    Counts by complement: of the ``len(values) * (len(values) + 1) // 2``
    non-empty subarrays, subtract every subarray that lies entirely inside a
    maximal run of nonzero (truthy) elements; a run of length k contains
    ``k * (k + 1) // 2`` such subarrays.

    Args:
        values: sequence of integers (any elements whose truthiness means
            "nonzero").

    Returns:
        The count as an ``int``; 0 for an empty sequence.

    >>> count_subarrays_with_zero([1, 0, 1])
    4
    """
    size = len(values)
    total = size * (size + 1) // 2
    zero_free = 0  # subarrays made only of nonzero elements
    run = 0  # length of the current run of nonzero elements
    for v in values:
        if v:
            run += 1
        else:
            # A zero closes the current run; bank its subarrays.
            zero_free += run * (run + 1) // 2
            run = 0
    zero_free += run * (run + 1) // 2  # trailing run, if any
    return total - zero_free


def main():
    """Read the problem input from stdin and print the answer."""
    # The declared length is consumed but not trusted: the count is derived
    # from the values actually read, so a mismatched n cannot skew the result.
    int(input())
    values = list(map(int, input().split()))
    print(count_subarrays_with_zero(values))


if __name__ == "__main__":
    main()