MOD = 998244353


def solve(S):
    """Count bounded non-decreasing sequences derived from the 'A's in S.

    For the k-th 'A' in ``S`` let ``b_k`` be one more than the number of
    non-'A' characters before it.  The function counts the integer
    sequences ``x_1 <= x_2 <= ... <= x_n`` with ``0 <= x_k < b_k``
    (n = number of 'A's), modulo ``MOD``.

    NOTE(review): this presumably counts the distinct strings obtainable by
    moving each 'A' leftwards past non-'A' characters (``x_k`` = how many
    non-'A' characters end up before the k-th 'A') — confirm against the
    problem statement.

    Args:
        S: Iterable of characters; only equality with ``'A'`` is inspected.

    Returns:
        The count modulo ``MOD``; ``1`` when ``S`` contains no 'A'.
    """
    from itertools import accumulate

    # bounds[k] = 1 + (number of non-'A' characters before the k-th 'A').
    # Index i minus the number of earlier 'A's is exactly that non-'A' count.
    bounds = []
    seen_a = 0
    for i, ch in enumerate(S):
        if ch == 'A':
            bounds.append(i + 1 - seen_a)
            seen_a += 1
    if not bounds:
        return 1

    # dp[x] = number of valid prefixes x_1..x_k whose last value is x.
    # The seed [1] stands for an implicit x_0 = 0.  Each bound is >= 1 and
    # the bounds never decrease, so dp only grows: pad with zeros up to the
    # new bound, then take prefix sums (x_k may be any value >= x_{k-1}).
    dp = [1]
    for b in bounds:
        dp.extend([0] * (b - len(dp)))
        # Every entry is < MOD, so the raw prefix sums stay small before reduction.
        dp = [v % MOD for v in accumulate(dp)]
    return sum(dp) % MOD