import java.util.*;

/**
 * Reads a count followed by that many integers from standard input and prints a single
 * combinatorial total computed from how many inputs equal 1, equal 2, or are 3 or more.
 */
public class Main {
    /**
     * Entry point: tallies the inputs into three buckets, then prints the weighted pair count.
     *
     * <p>Assumes every value read is at least 1. A smaller value maps to a negative bucket
     * index and throws {@link ArrayIndexOutOfBoundsException}.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        Scanner in = new Scanner(System.in);
        int n = in.nextInt();

        // bucket[0]: value 1, bucket[1]: value 2, bucket[2]: value >= 3
        long[] bucket = new long[3];
        for (int remaining = n; remaining > 0; remaining--) {
            int idx = Math.min(2, in.nextInt() - 1);
            bucket[idx]++;
        }

        long ones = bucket[0];
        long twos = bucket[1];
        long rest = bucket[2];
        long notOne = twos + rest;

        long total = ones * (ones - 1)
                + 2 * ones * rest
                + 3 * ones * twos
                + notOne * (notOne - 1) / 2;
        System.out.println(total);
    }
}