"""Read N and a list of integers, then print a closed-form pair total.

The values are split into three groups: above 2, exactly 2, and below 2.
The answer is computed from the sizes of those groups alone.
"""


def solve(values):
    """Return the closed-form total for the integer sequence *values*.

    With b = count(v > 2), c = count(v == 2), d = count(v < 2), this returns
    ``(b+c)*(b+c-1)//2 + d*(d + 2*b + 3*c - 1)``.

    Algebraically, this is a sum over unordered pairs with these weights:
    - 1 when both values are >= 2;
    - 2 when a value < 2 is paired with another value < 2 or with a value > 2;
    - 3 when a value < 2 is paired with a value equal to 2.

    NOTE(review): the original named this result ``mex``, which suggests a
    sum of pairwise MEX-like values; confirm against the problem statement.

    *values* may be any iterable of ints; it is consumed in a single pass.
    """
    above = equal = below = 0
    for v in values:
        if v > 2:
            above += 1
        elif v == 2:
            equal += 1
        else:
            below += 1
    at_least_two = above + equal
    # n*(n-1) is always even, so floor division is exact; integer arithmetic
    # avoids the float rounding the original `/ 2` + int() incurred.
    return (at_least_two * (at_least_two - 1) // 2
            + below * (below + 2 * above + 3 * equal - 1))


def main():
    """Read the input format (N, then N integers) and print the answer."""
    input()  # N: the declared count; the list line itself is authoritative
    values = list(map(int, input().split()))
    print(solve(values))


if __name__ == "__main__":
    main()