def count_subarrays_with_non_one(values):
    """Return how many contiguous subarrays of *values* contain an element != 1.

    All len(values) * (len(values) + 1) // 2 subarrays are counted, then every
    subarray lying entirely inside a maximal run of 1s is subtracted: a run of
    length k contains exactly k * (k + 1) // 2 such subarrays.

    Args:
        values: sequence of integers.

    Returns:
        The number of contiguous subarrays containing at least one value != 1.
        Returns 0 for an empty sequence.
    """
    total = len(values) * (len(values) + 1) // 2
    run = 0  # length of the current run of consecutive 1s
    for value in values:
        if value == 1:
            run += 1
        else:
            # A non-1 value ends the run; drop the all-ones subarrays inside it.
            total -= run * (run + 1) // 2
            run = 0
    # A run of 1s may extend to the end of the sequence; account for it too.
    total -= run * (run + 1) // 2
    return total


def main():
    """Read N and N integers from stdin and print the subarray count."""
    n = int(input())
    values = list(map(int, input().split()))
    # Only the first N values participate, matching the original script.
    print(count_subarrays_with_non_one(values[:n]))


if __name__ == "__main__":
    main()