N=int(input()) A=list(map(int,input().split())) cnt1=A.count(1) cnt2=A.count(2) ans=0 ans+=cnt1*(cnt1-1)+3*cnt1*cnt2+2*cnt1*(N-cnt1-cnt2)+(N-cnt1)*(N-cnt1-1)//2 print(ans)