import java.util.*;

/**
 * Reads {@code n} followed by {@code n} integers from standard input and prints the number of
 * ways to choose three of them (by position) whose values are pairwise distinct, modulo
 * {@link #MOD}.
 */
public class Main {
    static final int MOD = 1000000007;

    /**
     * Entry point: parses the input, counts distinct-valued triples and prints the result.
     *
     * @throws IllegalArgumentException if the declared element count is negative
     */
    public static void main(String[] args) {
        try (Scanner sc = new Scanner(System.in)) {
            int n = sc.nextInt();
            if (n < 0) {
                throw new IllegalArgumentException("n must be non-negative, got " + n);
            }
            int[] values = new int[n];
            for (int i = 0; i < n; i++) {
                values[i] = sc.nextInt();
            }
            System.out.println(countDistinctTriples(values));
        }
    }

    /**
     * Returns the number of position triples {@code i < j < k} of {@code values} whose values are
     * pairwise distinct, modulo {@link #MOD}. Arrays with fewer than three elements yield 0.
     */
    static long countDistinctTriples(int[] values) {
        Map<Integer, Integer> counts = new HashMap<>();
        for (int v : values) {
            counts.merge(v, 1, Integer::sum);
        }
        return countFromMultiplicities(counts.values());
    }

    /**
     * Given the sizes of disjoint groups, returns the number of ways to pick one element from each
     * of three different groups, modulo {@link #MOD}. This is the elementary symmetric polynomial
     * e3 of the sizes. It equals C(n,3) - sum(C(c,3) + C(c,2)*(n-c)) but needs no large
     * intermediate products. Sizes are expected to be non-negative.
     */
    static long countFromMultiplicities(Iterable<Integer> multiplicities) {
        long e1 = 0; // sum of sizes of groups seen so far
        long e2 = 0; // sum over pairs of earlier groups of the product of their sizes
        long e3 = 0; // sum over triples of earlier groups of the product of their sizes
        for (int c : multiplicities) {
            long cm = c % MOD;
            // Update highest degree first so each step uses the totals from before this group.
            // All operands are < MOD, so every product stays well below Long.MAX_VALUE.
            e3 = (e3 + e2 * cm) % MOD;
            e2 = (e2 + e1 * cm) % MOD;
            e1 = (e1 + cm) % MOD;
        }
        return e3;
    }
}