def pair_weight_sum(n, a):
    """Return the weighted total over all unordered pairs of elements.

    Every unordered pair of positions contributes a weight that depends
    only on the two values involved:

    * one value is 1 and the other is 2        -> 3
    * both values are 1                        -> 2
    * one value is 1, the other is not 1 or 2  -> 2
    * neither value is 1                       -> 1

    Args:
        n: Declared number of elements. It is used as-is for the counts of
            "other" elements, exactly as the script read it from input
            (presumably equal to ``len(a)`` -- confirm against the caller).
        a: List of integers.

    Returns:
        The integer sum of the pair weights described above.
    """
    ones = a.count(1)
    twos = a.count(2)
    others = n - ones - twos  # elements that are neither 1 nor 2
    non_ones = n - ones       # elements different from 1 (twos included)

    total = 3 * ones * twos                    # (1, 2) pairs
    total += 2 * (ones * (ones - 1) // 2)      # (1, 1) pairs: 2 * C(ones, 2)
    total += 2 * ones * others                 # (1, x) pairs with x not in {1, 2}
    total += non_ones * (non_ones - 1) // 2    # pairs without any 1: C(non_ones, 2)
    return total


def main():
    """Read ``n`` and the list from standard input and print the total."""
    n = int(input())
    a = list(map(int, input().split()))
    print(pair_weight_sum(n, a))


if __name__ == "__main__":
    main()