import sys

input = sys.stdin.readline

MOD = 998244353
PRIMITIVE_ROOT = 3

# Convolutions where either operand has at most this many entries use the
# schoolbook algorithm; a pure-Python NTT only pays off for longer inputs.
_NAIVE_CONVOLUTION_THRESHOLD = 32


# ----------------- ModInt -----------------
class ModInt:
    """Integer modulo MOD supporting +, -, *, /, ** with ModInt or int operands.

    Note: ``+=`` mutates the left operand in place instead of allocating a new
    object, so avoid ``+=`` on an instance that is shared between list slots.
    """

    __slots__ = ['x']

    def __init__(self, x):
        self.x = x % MOD

    def __add__(self, other):
        return ModInt(self.x + (other.x if isinstance(other, ModInt) else other))

    __radd__ = __add__  # lets ``int + ModInt`` and ``sum(...)`` work

    def __sub__(self, other):
        return ModInt(self.x - (other.x if isinstance(other, ModInt) else other))

    def __rsub__(self, other):
        # Only reached for ``int - ModInt``.
        return ModInt(other - self.x)

    def __mul__(self, other):
        return ModInt(self.x * (other.x if isinstance(other, ModInt) else other))

    __rmul__ = __mul__  # lets ``int * ModInt`` work

    def __pow__(self, power):
        return ModInt(pow(self.x, power, MOD))

    def inv(self):
        """Return the multiplicative inverse via Fermat (MOD is prime); 0 maps to 0."""
        return ModInt(pow(self.x, MOD - 2, MOD))

    def __truediv__(self, other):
        # Multiply by the inverse of the *divisor*.  (The previous version
        # multiplied by self.inv(), so ModInt / ModInt always returned 1.)
        divisor = other if isinstance(other, ModInt) else ModInt(other)
        return self * divisor.inv()

    def __iadd__(self, other):
        self.x = (self.x + (other.x if isinstance(other, ModInt) else other)) % MOD
        return self

    def __repr__(self):
        return str(self.x)

    def val(self):
        """Return the canonical representative in [0, MOD)."""
        return self.x


# ----------------- NTT / Convolution -----------------
def modinv(x):
    """Inverse of x modulo the prime MOD (Fermat's little theorem)."""
    return pow(x, MOD - 2, MOD)


def ntt(a, invert):
    """In-place iterative radix-2 number-theoretic transform over MOD.

    ``len(a)`` must be a power of two dividing MOD - 1.  With ``invert=True``
    the inverse transform is applied, including the final division by len(a).
    """
    n = len(a)
    # Bit-reversal permutation so the butterflies below can run in place.
    j = 0
    for i in range(1, n):
        bit = n >> 1
        while j & bit:
            j ^= bit
            bit >>= 1
        j ^= bit
        if i < j:
            a[i], a[j] = a[j], a[i]
    length = 2
    while length <= n:
        half = length >> 1
        # Primitive length-th root of unity (or its inverse for the inverse NTT).
        wlen = pow(PRIMITIVE_ROOT, (MOD - 1) // length, MOD)
        if invert:
            wlen = modinv(wlen)
        for start in range(0, n, length):
            w = 1
            for k in range(start, start + half):
                u = a[k]
                v = a[k + half] * w % MOD
                a[k] = (u + v) % MOD
                a[k + half] = (u - v) % MOD
                w = w * wlen % MOD
        length <<= 1
    if invert:
        n_inv = modinv(n)
        for i in range(n):
            a[i] = a[i] * n_inv % MOD


def _convolve_int(a, b):
    """Linear convolution of two int lists modulo MOD; inputs are not modified.

    Returns a list of length ``len(a) + len(b) - 1`` (empty if either input is
    empty).
    """
    if not a or not b:
        return []
    out_len = len(a) + len(b) - 1
    if min(len(a), len(b)) <= _NAIVE_CONVOLUTION_THRESHOLD:
        out = [0] * out_len
        for i, x in enumerate(a):
            if x:
                for j, y in enumerate(b):
                    out[i + j] += x * y
        return [v % MOD for v in out]
    size = 1
    while size < out_len:
        size <<= 1
    fa = a + [0] * (size - len(a))
    fb = b + [0] * (size - len(b))
    ntt(fa, False)
    ntt(fb, False)
    for i in range(size):
        fa[i] = fa[i] * fb[i] % MOD
    ntt(fa, True)
    return fa[:out_len]


def convolution(a, b):
    """Linear convolution of a and b (ModInt or int entries) as a ModInt list.

    Returns ``len(a) + len(b) - 1`` entries, or ``[]`` when either input is empty.
    """
    fa = [x.x if isinstance(x, ModInt) else x % MOD for x in a]
    fb = [x.x if isinstance(x, ModInt) else x % MOD for x in b]
    return [ModInt(v) for v in _convolve_int(fa, fb)]


# ----------------- Binomial -----------------
class Binomial:
    """Factorials / inverse factorials (as ModInt) for 0..MAX and C(n, k)."""

    def __init__(self, MAX):
        # Build every slot as its own object so an in-place ``+=`` on one
        # entry cannot leak into others through list aliasing.
        self.fact = [ModInt(1)]
        for i in range(1, MAX + 1):
            self.fact.append(self.fact[-1] * i)
        self.invfact = [None] * (MAX + 1)
        self.invfact[MAX] = self.fact[MAX].inv()
        for i in range(MAX, 0, -1):
            self.invfact[i - 1] = self.invfact[i] * i

    def C(self, n, k):
        """Binomial coefficient C(n, k); 0 unless 0 <= k <= n.  Requires n <= MAX."""
        if n < 0 or k < 0 or k > n:
            return ModInt(0)
        return self.fact[n] * self.invfact[k] * self.invfact[n - k]


# ----------------- Count increasing sequences (D&C + convolution) -----------------
def _factorial_tables(limit):
    """Return (fact, inv_fact) as plain int lists modulo MOD for 0..limit."""
    fact = [1] * (limit + 1)
    for i in range(1, limit + 1):
        fact[i] = fact[i - 1] * i % MOD
    inv_fact = [1] * (limit + 1)
    inv_fact[limit] = pow(fact[limit], MOD - 2, MOD)
    for i in range(limit, 0, -1):
        inv_fact[i - 1] = inv_fact[i] * i % MOD
    return fact, inv_fact


def _rectangle_transfer(width, height, left_in, bottom_in, fact, inv_fact):
    """Propagate values through a full width x height grid of cells.

    Cell (i, j) gets dp[i][j] = L(i, j) + dp[i][j - 1], where L(0, j) =
    left_in[j], L(i, j) = dp[i - 1][j] for i > 0, and dp[i][-1] = bottom_in[i].
    Returns (right, top) with right[j] = dp[width - 1][j] and
    top[i] = dp[i][height - 1]; for height == 0 the bottom inputs pass
    straight through to the top.  Each of the four input->output maps counts
    monotone lattice paths, i.e. a binomial-kernel convolution.
    """
    if height == 0:
        return [], list(bottom_in)
    right = [0] * height
    top = [0] * width
    span = fact[:width + height - 1]  # s! for s in 0..width+height-2
    if any(left_in):
        # left[k] -> right[j]: C(width - 1 + (j - k), width - 1) paths.
        kernel = [fact[width - 1 + d] * inv_fact[width - 1] % MOD * inv_fact[d] % MOD
                  for d in range(height)]
        conv = _convolve_int(left_in, kernel)
        for j in range(height):
            right[j] += conv[j]
        # left[k] -> top[i]: C(i + height - 1 - k, i)
        #   = (height-1-k+i)! / ((height-1-k)! * i!), a correlation done as a
        #   convolution against the factorial sequence.
        weighted = [left_in[k] * inv_fact[height - 1 - k] % MOD for k in range(height)]
        conv = _convolve_int(weighted, span)
        for i in range(width):
            top[i] += conv[height - 1 + i] * inv_fact[i]
    if any(bottom_in):
        # bottom[t] -> right[j]: C(width - 1 - t + j, j), same trick as above.
        weighted = [bottom_in[t] * inv_fact[width - 1 - t] % MOD for t in range(width)]
        conv = _convolve_int(weighted, span)
        for j in range(height):
            right[j] += conv[width - 1 + j] * inv_fact[j]
        # bottom[t] -> top[i] (t <= i): C(i - t + height - 1, height - 1) paths.
        kernel = [fact[d + height - 1] * inv_fact[height - 1] % MOD * inv_fact[d] % MOD
                  for d in range(width)]
        conv = _convolve_int(list(bottom_in), kernel)
        for i in range(width):
            top[i] += conv[i]
    return [v % MOD for v in right], [v % MOD for v in top]


def _staircase_transfer(heights, left_in, bottom_in, fact, inv_fact):
    """Right-edge output of a staircase region (same recurrence as a rectangle).

    Column i has cells 0..heights[i]-1; ``heights`` must be non-decreasing and
    non-negative, ``len(left_in) == heights[0]`` and
    ``len(bottom_in) == len(heights)``.  Returns dp of the last column.

    Split at m = n // 2 with h = heights[m - 1]: the left staircase, the full
    rectangle rows [0, h) x columns [m, n), and the staircase above it (which
    receives nothing from the left because column m - 1 has height h).
    """
    n = len(heights)
    if n == 1:
        base = bottom_in[0]
        out = []
        running = 0
        for j in range(heights[0]):
            running = (running + left_in[j]) % MOD
            out.append((running + base) % MOD)
        return out
    m = n // 2
    h = heights[m - 1]
    lower_left = _staircase_transfer(heights[:m], left_in, bottom_in[:m], fact, inv_fact)
    right_low, top = _rectangle_transfer(n - m, h, lower_left, bottom_in[m:], fact, inv_fact)
    upper_heights = [x - h for x in heights[m:]]
    right_high = _staircase_transfer(upper_heights, [0] * upper_heights[0], top,
                                     fact, inv_fact)
    return right_low + right_high


def count_increase_sequences_with_upper_bounds(A, C):
    """Count weighted non-decreasing sequences under per-position upper bounds.

    Returns a ModInt list ``out`` of length ``max(A[-1], 0)`` with, modulo MOD,

        out[j] = sum over s of C[s] * #{(x_s, ..., x_{n-1}) :
                  0 <= x_s <= ... <= x_{n-1} = j  and  x_t < A[t] for all t >= s}

    i.e. C[s] weights the sequences that begin at position s.  With
    C = [1, 0, ..., 0], out[j] is the number of full-length non-decreasing
    sequences ending in j; for n == 1 the result is [C[0]] * A[0].

    A need not be monotone: bounds are replaced by suffix minima (x_i <= x_t <
    A[t] for t >= i), which leaves the counted set unchanged.  Runs a
    divide & conquer over the lattice-path region with binomial-kernel
    convolutions, about O((n + max A) log n log(n + max A)) work.

    Raises:
        ValueError: if ``len(C) != len(A)``.
    """
    n = len(A)
    if n == 0:
        return []
    if len(C) != n:
        raise ValueError("A and C must have the same length")
    heights = list(A)
    for i in range(n - 2, -1, -1):
        heights[i] = min(heights[i], heights[i + 1])
    heights = [max(h, 0) for h in heights]
    bottom = [c.x if isinstance(c, ModInt) else c % MOD for c in C]
    # Binomial arguments stay below width + height <= n + max(heights).
    fact, inv_fact = _factorial_tables(n + max(heights) + 1)
    out = _staircase_transfer(heights, [0] * heights[0], bottom, fact, inv_fact)
    return [ModInt(v) for v in out]


# ----------------- Solve -----------------
def solve():
    """Read a line s and print a count modulo MOD.

    Prints 1 if s has no 'A' or no 'B'.  Otherwise, with h_i the number of
    non-'A' characters before the i-th 'A', prints the number of sequences
    0 <= x_1 <= x_2 <= ... <= x_a with x_i <= h_i.
    """
    s = input().strip()
    if s.count('A') == 0 or s.count('B') == 0:
        print(1)
        return
    bounds = []
    height = 0
    for ch in s:
        if ch == 'A':
            bounds.append(height + 1)  # exclusive bound: x_i ranges over 0..height
        else:
            height += 1
    # Weight only sequences starting at the first 'A' (full-length sequences);
    # all-ones weights would also count sequences that start later.
    weights = [1] + [0] * (len(bounds) - 1)
    ans = ModInt(0)
    for x in count_increase_sequences_with_upper_bounds(bounds, weights):
        ans += x
    print(ans.val())


if __name__ == "__main__":
    solve()