def permutation_rank(p):
    """Return the 1-based lexicographic rank of permutation *p* of 1..len(p).

    The rank is ``1 + sum(c_k * (n - 1 - k)!)`` over positions ``k``, where
    ``c_k`` is the number of entries to the right of ``p[k]`` that are smaller
    than it. The counts are gathered right-to-left with a Fenwick (binary
    indexed) tree, so the whole computation is O(n log n) tree operations.

    Args:
        p: A sequence intended to hold each integer 1..n exactly once
            (n = len(p)). Distinctness is assumed, not checked.

    Returns:
        The rank as an arbitrary-precision int. The identity permutation
        (and the empty sequence) has rank 1.

    Raises:
        ValueError: if some value lies outside 1..len(p). Without this check,
            a value <= 0 would make the tree update loop forever, and a value
            > n would index past the tree.
    """
    n = len(p)
    for value in p:
        if not 1 <= value <= n:
            raise ValueError(f"value {value} is outside the range 1..{n}")

    tree = [0] * (n + 1)  # 1-indexed Fenwick tree over values 1..n

    def prefix_count(v):
        """Return how many inserted values are <= v."""
        total = 0
        while v > 0:
            total += tree[v]
            v -= v & -v  # strip the lowest set bit
        return total

    def insert(v):
        """Record value v (1..n) as seen."""
        while v <= n:
            tree[v] += 1
            v += v & -v  # advance to the next node covering v

    rank = 1
    factorial = 1  # (n - 1 - k)! for the position k currently being scored
    for steps, value in enumerate(reversed(p), start=1):
        # Every value inserted so far lies to the right of `value`. Because the
        # values are distinct, "<= value" among them means "< value".
        rank += factorial * prefix_count(value)
        insert(value)
        factorial *= steps
    return rank


def main():
    """Read n, then a permutation of 1..n, from stdin and print its rank.

    Only the first n values on the second line are used; extra tokens are
    ignored.

    Raises:
        ValueError: if fewer than n values are supplied, or a value lies
            outside 1..n.
    """
    n = int(input())
    values = list(map(int, input().split()))
    if len(values) < n:
        raise ValueError(f"expected {n} values, got {len(values)}")
    print(permutation_rank(values[:n]))


if __name__ == "__main__":
    main()