from collections import defaultdict


def _longest_balanced_span(s):
    """Return the length of the longest balanced contiguous substring of *s*.

    A substring is balanced when it contains as many 'A' characters as
    characters that are not 'A'. Returns 0 when no such non-empty
    substring exists (including when *s* is empty).
    """
    # Running balance: +1 for every 'A', -1 for any other character.
    # Two prefixes with the same balance delimit a balanced substring,
    # so the earliest index of each balance gives the widest span.
    first_seen = {0: 0}  # the empty prefix has balance 0 at index 0
    balance = 0
    best = 0
    for end, ch in enumerate(s, start=1):
        balance += 1 if ch == 'A' else -1
        # setdefault records `end` only the first time a balance appears;
        # in that case start == end and the span contributes 0.
        start = first_seen.setdefault(balance, end)
        best = max(best, end - start)
    return best


def main():
    """Read one line from standard input and print its longest balanced span.

    The printed value is the length of the longest contiguous substring
    with equal counts of 'A' and non-'A' characters, or 0 if there is none.
    """
    print(_longest_balanced_span(input()))


# Run only when executed as a script, so importing this module does not
# block on standard input.
if __name__ == "__main__":
    main()