n = int(input()) a = list(map(int,input().split())) c = [0]*3 for x in a: if x <= 2: c[x] += 1 ans = 0 s = n for x in a: s -= 1 if x <= 2: c[x] -= 1 if x == 1: ans += 2*(s-c[2]) ans += 3*c[2] elif x == 2: ans += 1*(s-c[1]) ans += 3*c[1] else: ans += 1*(s-c[1]) ans += 2*c[1] print(ans)