"""Sum a per-pair "mex" weight over all unordered index pairs.

Each pair (i < j) is weighted by the smallest positive integer missing
from {A[i], A[j]}:
  * pair contains both a 1 and a 2   -> 3
  * pair contains a 1 but no 2       -> 2
  * pair contains no 1               -> 1

Input (stdin): N on the first line, then the integers of A on the second.
Output (stdout): the total weight.
"""


def pair_mex_sum(n, values):
    """Return the sum of the pairwise weights described in the module docstring.

    Args:
        n: Element count used for the total number of pairs, C(n, 2).
        values: The integers; only the occurrences of 1 and 2 matter.

    Returns:
        The sum, over all pairs i < j, of the smallest positive integer
        missing from {values[i], values[j]}.
    """
    ones = values.count(1)
    twos = values.count(2)
    others = n - ones - twos

    # Weight 3: the pair must be exactly one 1 and one 2.
    mex3 = ones * twos
    # Weight 2: a 1 but no 2, i.e. (1, 1) pairs plus (1, other) pairs.
    mex2 = ones * (ones - 1) // 2 + ones * others
    # Weight 1: every remaining pair lacks a 1.
    mex1 = n * (n - 1) // 2 - mex2 - mex3
    return mex1 + 2 * mex2 + 3 * mex3


def main():
    """Read N and the array from stdin and print the pairwise weight sum."""
    n = int(input())
    values = list(map(int, input().split()))
    print(pair_mex_sum(n, values))


if __name__ == "__main__":
    main()