import java.util.Scanner;

/**
 * Reads a count {@code N} followed by {@code N} integers from standard input, classifies each
 * integer as equal to 1, equal to 2, or at least 3, and prints a single number computed from the
 * sizes of those three groups.
 *
 * <p>NOTE(review): the combinatorial meaning of the printed formula comes from a problem statement
 * that is not part of this file; the formula is kept exactly as originally written.
 */
public class Main {
    public static void main(String[] args) throws Exception {
        Scanner in = new Scanner(System.in);
        int n = in.nextInt();

        // bucket[0]: values equal to 1, bucket[1]: values equal to 2, bucket[2]: values >= 3.
        // A value below 1 maps to a negative index and throws ArrayIndexOutOfBoundsException.
        long[] bucket = new long[3];
        for (int k = 0; k < n; k++) {
            int idx = Math.min(2, in.nextInt() - 1);
            bucket[idx]++;
        }

        long ones = bucket[0];
        long twos = bucket[1];
        long bigs = bucket[2];
        long notOne = twos + bigs;

        // ones * (ones + bigs - 1)  ==  ones * (ones - 1) + ones * bigs
        // notOne * (notOne - 1) / 2 ==  number of unordered pairs among the non-1 values
        long total = ones * (ones + bigs - 1)
                + 3 * ones * twos
                + notOne * (notOne - 1) / 2;

        System.out.println(total);
    }
}