n = int(input()) A = list(map(int, input().split())) A.sort() c = 0 ans = 0 for i, a in enumerate(A): if a == 1: ans += c*2 c += 1 elif a == 2: ans += c*3 ans += (i-c)*1 else: ans += c*2 ans += (i-c)*1 print(ans)