#!/usr/bin/python
"""Weighted common-prefix sum over contiguous runs of strings.

Reads an integer N followed by N lines (one string each) from stdin and
prints

    sum over 0 <= l <= r < N of (r - l + 1) * LCP(S[l], ..., S[r])

where LCP of a run is the length of the longest prefix shared by every
string in it (a run of one string contributes its own length).
"""


def _lcp(x, y):
    """Return the length of the longest common prefix of ``x`` and ``y``."""
    length = 0
    for cx, cy in zip(x, y):
        if cx != cy:
            break
        length += 1
    return length


def solve(strings):
    """Return the weighted common-prefix sum of ``strings``.

    Args:
        strings: sequence of strings, in input order.

    Returns:
        Sum over every contiguous run ``strings[l..r]`` of
        ``(r - l + 1) * LCP(run)``; single-string runs contribute their
        length. Returns 0 for an empty sequence.
    """
    total = sum(len(s) for s in strings)

    # adj[k] = LCP(strings[k], strings[k + 1]).  For r > l the LCP of the
    # run strings[l..r] equals min(adj[l..r-1]).
    adj = [_lcp(x, y) for x, y in zip(strings, strings[1:])]
    m = len(adj)

    # nxt[k] = first index j > k with adj[j] < adj[k], or m if there is none.
    nxt = [m] * m
    stack = []
    for k, value in enumerate(adj):
        while stack and adj[stack[-1]] > value:
            nxt[stack.pop()] = k
        stack.append(k)

    for left in range(m):
        cur = left
        # Follow successive strictly smaller prefix minima starting at
        # `left`.  Values strictly decrease and the walk stops at 0, so it
        # takes at most adj[left] steps.  Total work is bounded by the input
        # size.
        while cur < m and adj[cur] > 0:
            # For every right end r in [cur + 1, nxt[cur]], the minimum of
            # adj[left..r-1] is adj[cur].  Those runs have lengths
            # lo..hi, whose sum is an arithmetic series.
            lo = cur - left + 2
            hi = nxt[cur] - left + 1
            run_lengths = (lo + hi) * (hi - lo + 1) // 2
            total += run_lengths * adj[cur]
            cur = nxt[cur]
    return total


def main():
    """Read N and N strings from stdin, then print ``solve`` of them."""
    count = int(input())
    strings = [input() for _ in range(count)]
    print(solve(strings))


if __name__ == '__main__':
    main()