def count_weighted_pairs(n, values):
    """Return the weighted count of unordered index pairs (i < j) among n items.

    Each pair is weighted by the values it contains:
      * a pair holding one 1 and one 2 counts 3,
      * a pair holding two 1s counts 2,
      * every other pair counts 1.

    NOTE(review): the meaning of these weights comes from the original task
    statement, which is not visible here; only the arithmetic is reproduced.

    Args:
        n: number of items; the total number of pairs is n * (n - 1) // 2.
        values: the item values; only occurrences of 1 and 2 matter.

    Returns:
        The weighted pair count as an int.
    """
    ones = values.count(1)
    twos = values.count(2)
    total_pairs = n * (n - 1) // 2
    one_two_pairs = ones * twos
    one_one_pairs = ones * (ones - 1) // 2
    # Every pair contributes 1. A (1, 2) pair adds 2 more (3 in total) and a
    # (1, 1) pair adds 1 more (2 in total). This is algebraically identical to
    # 3*one_two + 2*one_one + (total - one_two - one_one).
    return total_pairs + 2 * one_two_pairs + one_one_pairs


def main():
    """Read n and the values from stdin, then print the weighted pair count."""
    n = int(input())
    a = list(map(int, input().split()))
    print(count_weighted_pairs(n, a))


if __name__ == "__main__":
    main()