def _permutation_rank(perm):
    """Return the 1-based lexicographic rank of *perm* among all orderings of 1..len(perm).

    This uses the Lehmer code. Walking from the right end, a Fenwick tree
    (binary indexed tree) counts how many already-seen values are smaller
    than the current one. That count is weighted by the number of orderings
    of the elements to its right, which is a factorial.

    Raises:
        ValueError: if *perm* is not a permutation of 1..len(perm). Without
            this check, a value <= 0 would make the Fenwick update loop spin
            forever, because ``0 & -0 == 0``.
    """
    n = len(perm)
    if sorted(perm) != list(range(1, n + 1)):
        raise ValueError("input is not a permutation of 1..n")

    tree = [0] * (n + 1)  # 1-indexed Fenwick tree of "value already seen" counts
    rank = 1
    weight = 1  # (k-1)!: number of orderings of the k-1 elements right of the current one
    for k, value in enumerate(reversed(perm), start=1):
        # Prefix query: count of seen values <= value. Values are distinct and
        # value itself is not yet inserted, so this is the count of smaller values.
        smaller = 0
        i = value
        while i > 0:
            smaller += tree[i]
            i -= i & -i

        # Point update: mark value as seen for the positions further left.
        i = value
        while i <= n:
            tree[i] += 1
            i += i & -i

        rank += weight * smaller
        weight *= k
    return rank


def main():
    """Read n and a permutation of 1..n from stdin; print its 1-based lexicographic rank.

    Input format: the first line holds n; the second line holds the
    whitespace-separated values. Only the first n values are used.

    Raises:
        ValueError: if fewer than n values are given, or they are not a
            permutation of 1..n.
    """
    n = int(input())
    p = list(map(int, input().split()))
    if len(p) < n:
        raise ValueError(f"expected {n} values, got {len(p)}")
    print(_permutation_rank(p[:n]))


if __name__ == "__main__":
    main()