"""Read N integers and print a weighted count of unordered index pairs.

Each unordered pair of positions (i < j) contributes:
  * 3 if the two values are exactly {1, 2};
  * 2 if at least one value is 1 (and the pair is not {1, 2});
  * 1 otherwise.
"""


def weighted_pair_sum(values, n=None):
    """Return the weighted pair count described in the module docstring.

    Args:
        values: sequence of integers.
        n: total number of elements to assume; defaults to ``len(values)``.
            The script passes the declared N from input so that the result
            matches the original computation exactly.

    Returns:
        The sum of pair weights as an int (0 for fewer than two elements).
    """
    if n is None:
        n = len(values)
    ones = values.count(1)
    twos = values.count(2)
    # Elements that are neither 1 nor 2.
    rest = n - ones - twos

    # Pairs {1, 2}: weight 3.
    pairs_one_two = ones * twos
    # Pairs {1, 1} and {1, other}: weight 2.
    pairs_with_one = ones * (ones - 1) // 2 + ones * rest
    # Pairs {2, 2}, {2, other}, {other, other}: weight 1.
    pairs_plain = twos * (twos - 1) // 2 + twos * rest + rest * (rest - 1) // 2

    return pairs_plain + 2 * pairs_with_one + 3 * pairs_one_two


def main():
    """Read N and the N values from stdin, then print the weighted pair count."""
    n = int(input())
    values = list(map(int, input().split()))
    print(weighted_pair_sum(values, n))


if __name__ == "__main__":
    main()