def total_pair_score(values):
    """Return the summed pair score over all unordered index pairs i < j.

    Each pair scores 1, plus 1 if either element equals 1, plus 1 more if the
    pair is exactly {1, 2}. So a {1, 2} pair is worth 3, a pair containing a 1
    (but not {1, 2}) is worth 2, and any pair without a 1 is worth 1.
    Counting elements by category gives the answer in O(n) instead of
    enumerating all O(n^2) pairs.

    Args:
        values: sequence of integers.

    Returns:
        The total score as an int (0 for fewer than two elements).
    """
    ones = values.count(1)
    twos = values.count(2)
    others = len(values) - ones - twos

    # Pairs worth 3: one 1 together with one 2.
    worth3 = ones * twos
    # Pairs worth 2: two 1s, or a 1 with a value that is neither 1 nor 2.
    worth2 = ones * (ones - 1) // 2 + ones * others
    # Pairs worth 1: every pair that contains no 1 at all.
    worth1 = (
        twos * (twos - 1) // 2
        + twos * others
        + others * (others - 1) // 2
    )
    return worth1 + 2 * worth2 + 3 * worth3


def main():
    """Read N and then N integers from stdin; print their total pair score."""
    input()  # N: the element count is taken from the parsed list itself
    values = list(map(int, input().split()))
    print(total_pair_score(values))


if __name__ == "__main__":
    main()