import sys
def solve(n, values):
    """Return the closed-form count computed from the multiset of *values*.

    Elements are bucketed into three groups: ones, twos, and everything
    else.  The "everything else" size is derived from the declared length
    *n* (not ``len(values)``), matching the original script, which trusted
    the count given on the first input line.

    Algebraically, the result equals a weighted sum over unordered index
    pairs {i, j} (i != j), where the weight depends only on the buckets of
    the two elements::

        (1, 1) -> 2      (1, 2) -> 3      (1, other) -> 2
        (2, 2) -> 1      (2, other) -> 2  (other, other) -> 1

    Args:
        n: Declared number of elements.
        values: The integers themselves.

    Returns:
        The weighted pair count as an int.
    """
    ones = values.count(1)
    twos = values.count(2)
    others = n - ones - twos

    not_one = twos + others
    # Every unordered pair of non-one elements contributes 1 ...
    total = not_one * (not_one - 1) // 2
    # ... and (2, other) pairs contribute one extra.
    total += twos * others
    # ones * (ones - 1) == 2 * C(ones, 2): weight 2 per (1, 1) pair.
    total += ones * (ones - 1)
    total += 2 * ones * others
    total += 3 * ones * twos
    return total


def main():
    """Read ``N`` and the ``N`` integers from stdin, then print the answer."""
    # Local alias for fast line reads without shadowing the builtin input().
    readline = sys.stdin.readline
    n = int(readline())
    values = list(map(int, readline().split()))
    print(solve(n, values))


if __name__ == "__main__":
    main()