from collections import defaultdict


def count_triples(s: str) -> int:
    """Count index triples i < j < k with s[i] == s[j] != s[k].

    Positions are grouped per character.  For the j-th occurrence
    (1-indexed) of a character at 1-indexed position v, where that
    character occurs n times in total:

      * (j - 1) earlier occurrences can serve as index i, and
      * (N - v) positions follow v, of which (n - j) hold the same
        character, so (N - v) - (n - j) can serve as index k.

    Summing (j - 1) * ((N - v) - (n - j)) over all occurrences gives the
    answer in O(N) time.

    Args:
        s: The input string (any characters; compared by equality).

    Returns:
        The number of qualifying triples (0 for strings shorter than 3).
    """
    positions = defaultdict(list)
    for pos, ch in enumerate(s, start=1):
        positions[ch].append(pos)

    total_len = len(s)
    total = 0
    for occ in positions.values():
        n = len(occ)
        for j, v in enumerate(occ, start=1):
            # The first occurrence (j == 1) has no earlier partner, so the
            # (j - 1) factor makes its contribution zero without a branch.
            total += (j - 1) * (total_len - v - (n - j))
    return total


def main() -> None:
    """Read the string from one line of stdin and print the triple count."""
    print(count_triples(input()))


if __name__ == "__main__":
    main()