def pair_score(values, n=None):
    """Return the closed-form pair score for *values*.

    Every unordered pair of positions contributes a base weight of 1
    (the ``n*(n-1)//2`` term). Pairs involving a value equal to 1 add
    extra weight on top of that:

    * pair (1, 1)   -> +1  (``ones*(ones-1)//2`` term)
    * pair (1, 2)   -> +2  (``2*ones*twos`` term)
    * pair (1, >=3) -> +1  (``ones*big`` term)

    Values are bucketed by ``min(v, 3)``, so every value >= 3 is treated
    alike. A value of 0 only contributes through the base term.
    NOTE(review): the originating problem statement is not visible here.
    Inputs are presumably non-negative. As in the original script, a
    negative value indexes the bucket list from the end, and values
    <= -5 raise IndexError.

    Args:
        values: the integers to score.
        n: total element count used for the base pair term. It defaults
            to ``len(values)``. The script entry point passes the ``n``
            read from input, so its output is unchanged even if the two
            disagree.

    Returns:
        The integer score.
    """
    if n is None:
        n = len(values)
    counts = [0, 0, 0, 0]  # buckets for 0, 1, 2, and ">= 3"
    for v in values:
        counts[min(v, 3)] += 1
    ones, twos, big = counts[1], counts[2], counts[3]
    return (
        ones * (ones - 1) // 2
        + ones * big
        + 2 * ones * twos
        + n * (n - 1) // 2
    )


def main():
    """Read ``n`` and the values from stdin and print their pair score."""
    n = int(input())
    values = list(map(int, input().split()))
    print(pair_score(values, n))


if __name__ == "__main__":
    main()