N = int(input()) A = list(map(int, input().split())) cnt1 = A.count(1) cnt2 = A.count(2) M = N * (N - 1) // 2 ans = 3 * cnt1 * cnt2 M -= cnt1 * cnt2 ans += 2 * cnt1 * (cnt1 - 1) // 2 M -= cnt1 * (cnt1 - 1) // 2 print(ans + M)