N = int(input()) A = list(map(int, input().split())) cnt = [0]*4 for a in A: if a >= 3: cnt[3] += 1 else: cnt[a] += 1 ans = 0 # mex(1, 2) = 3 ans += cnt[1] * cnt[2] * 3 # mex(1, 1) = 2 ans += (cnt[1]*(cnt[1] - 1)//2) * 2 # mex(1, 1と2以外) = 2 ans += (cnt[1] * (N - cnt[1] - cnt[2])) * 2 # mex(1以外, 1以外) = 1 N -= cnt[1] ans += N*(N - 1)//2 * 1 print(ans)