from collections import Counter
from sys import stdin


def count_pairs(values):
    """Return the weighted count of unordered index pairs over *values*.

    Each unordered pair (i < j) contributes a weight that depends only on
    whether its elements are 1, 2, or something else:

    * both elements are 1                        -> 2
    * one element is 1, the other is 2           -> 3
    * one element is 1, the other is not 1 or 2  -> 2
    * neither element is 1                       -> 1

    NOTE(review): these weights are exactly what the original formula
    encodes; the underlying problem statement is not visible here, so the
    meaning of each weight is not documented further.

    Args:
        values: iterable of ints (consumed once).

    Returns:
        The total weight as a Python int (arbitrary precision, no overflow).
        Empty or single-element input yields 0.
    """
    cnt = Counter(values)
    total = sum(cnt.values())
    # Direct lookups: unlike inspecting the first items of the sorted
    # counter, this does not assume 1 is the smallest value present.
    one = cnt[1]
    two = cnt[2]
    other = total - one - two

    ans = one * (one - 1)        # C(one, 2) pairs of 1s, weight 2 each
    ans += 3 * one * two         # (1, 2) pairs
    ans += 2 * one * other       # (1, x) pairs with x not in {1, 2}
    rest = total - one
    ans += rest * (rest - 1) // 2  # C(rest, 2) pairs with no 1, weight 1
    return ans


def main():
    """Read N followed by N integers from stdin and print count_pairs of them.

    Input is tokenised with split() rather than by stripping the last
    character of each line, so a missing trailing newline (or CRLF line
    endings) cannot corrupt the final number.
    """
    tokens = stdin.read().split()
    if not tokens:
        print(0)
        return
    n = int(tokens[0])
    values = map(int, tokens[1:1 + n])
    print(count_pairs(values))


if __name__ == "__main__":
    main()