def count_pair_score(values):
    """Return the weighted count of unordered index pairs in *values*.

    The closed form below equals summing, over every unordered pair of
    positions (i < j), a weight that depends only on the two values:

    * neither value is 1                               -> 1
    * both values are 1                                -> 2
    * one value is 1 and the other is 2                -> 3
    * one value is 1 and the other is neither 1 nor 2  -> 2

    NOTE(review): the problem statement this answers is not visible here;
    the weight table is derived purely from the formula.

    Runs in O(len(values)) time (two counting passes).
    """
    total = len(values)
    ones = values.count(1)
    twos = values.count(2)
    non_ones = total - ones
    others = total - ones - twos  # values that are neither 1 nor 2

    return (
        non_ones * (non_ones - 1) // 2  # pairs without a 1: weight 1 each
        + ones * (ones - 1)             # 2 * C(ones, 2): pairs of two 1s
        + 3 * ones * twos               # (1, 2) pairs: weight 3 each
        + 2 * ones * others             # (1, other) pairs: weight 2 each
    )


def main():
    """Read N and then N integers from stdin; print the weighted pair count.

    Raises:
        ValueError: if the second line holds fewer than N integers.
    """
    n = int(input())
    tokens = list(map(int, input().split()))
    if len(tokens) < n:
        # The original indexed A[i] for i < N and crashed with IndexError here;
        # fail loudly with a clearer message instead of miscomputing.
        raise ValueError(f"expected {n} integers, got {len(tokens)}")
    # Only the first N integers are used, matching the original range(N) loop.
    print(count_pair_score(tokens[:n]))


if __name__ == "__main__":
    main()