def solve(n: int, a: list) -> int:
    """Return the closed-form total for ``n`` and the values in ``a``.

    With ``p = a.count(1)`` and ``q = a.count(2)`` the result is::

        n * (n - 1) // 2  +  p * (n - 1) - p * (p - 1) // 2  +  p * q

    Args:
        n: The count read from the first input line. It is used directly
            in the formula; ``len(a)`` is never consulted.
        a: The integers read from the second input line.

    Returns:
        The value of the formula above as an ``int``.

    NOTE(review): the original problem statement is not visible here.
    Presumably ``n == len(a)``; under that assumption the three terms count,
    respectively, all unordered pairs, the unordered pairs containing at
    least one ``1``, and the pairs made of one ``1`` and one ``2``. Confirm
    against the task description.
    """
    ones = a.count(1)
    twos = a.count(2)

    all_pairs = n * (n - 1) // 2
    # ones * (n - 1) counts pairs made of two 1s twice; subtracting
    # C(ones, 2) removes the duplicate count.
    pairs_with_a_one = ones * (n - 1) - ones * (ones - 1) // 2
    one_two_pairs = ones * twos

    return all_pairs + pairs_with_a_one + one_two_pairs


def main() -> None:
    """Read ``n`` and the integer list from stdin and print the answer."""
    n = int(input())
    a = list(map(int, input().split()))
    print(solve(n, a))


if __name__ == "__main__":
    main()