"""Read N integers and print a closed-form expression over how many equal 1, equal 2, or are anything else.

NOTE(review): the combinatorial meaning of the formula is not stated in the
original script; the terms look like pair counts, but confirm against the
problem statement before relying on that interpretation.
"""
import sys


def solve(values, n=None):
    """Evaluate the script's closed-form expression for *values*.

    Args:
        values: sequence of integers supporting ``.count()``.
        n: declared element count. Defaults to ``len(values)``. The original
            script derived the "other" count from the declared N on the first
            input line, not from the number of parsed tokens, so ``main``
            passes it through to keep that exact behavior.

    Returns:
        ``C(twos + others, 2) + twos * others + ones * (ones - 1)
        + 3 * ones * twos`` as an int, where ``ones``/``twos`` are the number
        of elements equal to 1/2 and ``others = n - ones - twos``.
    """
    if n is None:
        n = len(values)
    ones = values.count(1)
    twos = values.count(2)
    others = n - ones - twos
    not_one = twos + others  # every element that is not a 1
    return (
        not_one * (not_one - 1) // 2  # unordered pairs among non-1 elements
        + twos * others
        + ones * (ones - 1)  # ordered pairs among the 1s
        + 3 * ones * twos
    )


def main():
    """Read ``N`` then ``N`` integers from stdin and print ``solve`` of them."""
    # Local alias instead of rebinding the builtin ``input``.
    readline = sys.stdin.readline
    n = int(readline())
    values = list(map(int, readline().split()))
    print(solve(values, n))


if __name__ == "__main__":
    main()