def mex(a: int, b: int) -> int:
    """Return the per-pair weight used by the scoring loop in ``solve``.

    Args:
        a: Ignored by the computation; kept so the two-argument call
            signature stays unchanged.
        b: The value being weighted.

    Returns:
        ``b + 1`` (i.e. 3) when ``b == 2``, otherwise 2.
    """
    # NOTE(review): the name suggests "minimum excluded value" of the pair
    # (1, b) over positive integers -- confirm against the problem statement.
    return b + 1 if b == 2 else 2


def solve(values: list) -> int:
    """Compute the script's answer for a list of integers.

    The total is the sum of three terms, where ``ones`` is the number of
    elements equal to 1 and ``n`` is ``len(values)``:

    * ``ones * (ones - 1)``
    * ``ones * sum(mex(_, x))`` over the sorted values from index ``ones``
      onward
    * ``(n - ones) * (n - ones - 1)``

    Args:
        values: The input integers, in any order. Not modified.

    Returns:
        The accumulated total; 0 for an empty list.
    """
    ordered = sorted(values)
    n = len(ordered)
    ones = ordered.count(1)
    others = n - ones

    # A product of two consecutive integers is never negative, so no
    # clamping with max(0, ...) is needed for either of these terms.
    total = ones * (ones - 1)

    # ordered[0] is evaluated lazily, only when the slice is non-empty, so
    # an empty input no longer raises IndexError. mex ignores it anyway.
    total += ones * sum(mex(ordered[0], x) for x in ordered[ones:])

    total += others * (others - 1)
    return total


def main() -> None:
    """Read N and the N integers from stdin and print the answer."""
    # The declared count is consumed to honour the input format; the number
    # of values actually parsed drives the computation, so a mismatch
    # cannot cause an out-of-range index.
    int(input())
    values = list(map(int, input().split()))
    print(solve(values))


if __name__ == "__main__":
    main()