def mex_pair_sum(ones, twos, others):
    """Return the sum of mex({a, b}) over all unordered pairs of elements.

    The multiset is described only by three counts:
      ones   -- number of elements equal to 1
      twos   -- number of elements equal to 2
      others -- number of elements that are neither 1 nor 2

    Each pair class contributes (number of such pairs) * (mex of the pair):
      (1, 1)         -> 2   C(ones, 2) pairs
      (1, 2)         -> 3   ones * twos pairs
      (1, other)     -> 2   ones * others pairs
      (2, 2)         -> 1   C(twos, 2) pairs
      (2, other)     -> 1   twos * others pairs
      (other, other) -> 1   C(others, 2) pairs

    NOTE(review): these weights match mex(S) = smallest positive integer not
    in S; confirm against the problem statement.

    No "count > 1" guards are needed: k * (k - 1) // 2 is already 0 for
    k in {0, 1}, and the cross terms must be counted even when a value
    appears exactly once (e.g. [1, 2] has one (1, 2) pair worth 3).
    """
    same_one = ones * (ones - 1) // 2
    same_two = twos * (twos - 1) // 2
    same_other = others * (others - 1) // 2
    return (
        2 * same_one
        + 3 * ones * twos
        + 2 * ones * others
        + same_two
        + twos * others
        + same_other
    )


def main():
    """Read N and the sequence A from stdin and print the pairwise mex sum.

    Input (as read below): the first line holds N, the second line holds the
    integers of A separated by whitespace.
    """
    n = int(input())
    values = list(map(int, input().split()))
    ones = values.count(1)
    twos = values.count(2)
    # Elements that are neither 1 nor 2 all carry the same mex weights,
    # so only their total count matters.
    others = n - ones - twos
    print(mex_pair_sum(ones, twos, others))


if __name__ == "__main__":
    main()