"""Read N and an integer array A from stdin and print a weighted pair count.

The printed value is ``C(n1, 2) + 2*n1*n2 + C(N, 2)``, where ``n1`` and ``n2``
are the numbers of elements of A equal to 1 and 2, and ``C(k, 2) = k*(k-1)/2``.
Equivalently, it is the sum over all unordered index pairs ``i < j`` of a
weight that is 2 when both values are 1, 3 when the values are {1, 2}, and 1
otherwise.
"""
import sys
from collections import Counter
from typing import Iterable


def solve(n: int, a: Iterable[int]) -> int:
    """Return the weighted pair count for ``n`` elements with values ``a``.

    Args:
        n: Declared number of elements; used for the total pair count
            ``C(n, 2)`` exactly as the original script used ``N``.
        a: The element values; only the counts of 1s and 2s matter.

    Returns:
        ``C(ones, 2) + 2 * ones * twos + C(n, 2)``.
    """
    counts = Counter(a)
    ones, twos = counts[1], counts[2]
    # Simplified from the original expression
    #   2*C(ones,2) + 3*ones*twos + C(n,2) - ones*twos - C(ones,2)
    # k*(k-1) is always even, so the floor division is exact.
    ones_pairs = ones * (ones - 1) // 2
    all_pairs = n * (n - 1) // 2
    return ones_pairs + 2 * ones * twos + all_pairs


def main() -> None:
    """Parse N and A from stdin and print ``solve(N, A)``."""
    data = sys.stdin.read().split()
    n = int(data[0])
    # Token-based parsing tolerates A being wrapped over several lines, or an
    # absent A line when N == 0 (line-based input() raised EOFError there).
    a = map(int, data[1:1 + n])
    print(solve(n, a))


if __name__ == "__main__":
    main()