"""Count the contiguous subarrays that contain at least one zero.

Input (stdin): an integer n, followed by n integers.
Output (stdout): the number of non-empty contiguous subarrays of the
sequence that contain at least one zero element.
"""
import sys


def count_subarrays_with_zero(values):
    """Return the number of non-empty contiguous subarrays of *values* containing a zero.

    Computed by complement: there are n*(n+1)//2 non-empty subarrays in
    total, and each maximal run of k consecutive non-zero (truthy)
    elements contains exactly k*(k+1)//2 subarrays with no zero in them.

    Args:
        values: sequence of integers (any element whose truthiness is
            False counts as a zero).

    Returns:
        The count of subarrays that contain at least one zero; 0 for an
        empty sequence.

    >>> count_subarrays_with_zero([1, 0, 1])
    4
    """
    n = len(values)
    zero_free = 0  # subarrays consisting only of non-zero elements
    run = 0  # length of the current run of non-zero elements
    for value in values:
        if value:
            run += 1
        else:
            zero_free += run * (run + 1) // 2
            run = 0
    # Flush the trailing run unconditionally: if the sequence ended in a
    # zero, run is already 0 and this adds nothing. Guarding it on the last
    # element would crash on an empty sequence.
    zero_free += run * (run + 1) // 2
    return n * (n + 1) // 2 - zero_free


def main():
    """Read n and n integers from stdin and print the subarray count."""
    # Token-based reading tolerates a missing second line (n == 0) and
    # numbers spread across several lines.
    tokens = sys.stdin.read().split()
    n = int(tokens[0])
    values = list(map(int, tokens[1:1 + n]))
    print(count_subarrays_with_zero(values))


if __name__ == "__main__":
    main()