import sys
import math
from typing import List, Tuple, Optional
import sys

MOD = 998244353
PRIMITIVE_ROOT = 3


# ---------- ModInt class ----------
class ModInt:
    """Integer modulo MOD with arithmetic operator overloads.

    No method mutates ``self.x`` after construction; every operator returns a
    new ModInt, so lists may safely share one instance (e.g. ``[ModInt(0)] * n``).
    """

    def __init__(self, x=0):
        if isinstance(x, ModInt):
            self.x = x.x
        else:
            # Python's % with a positive modulus already yields [0, MOD).
            self.x = x % MOD

    def __int__(self):
        return self.x

    def __str__(self):
        return str(self.x)

    def __repr__(self):
        return f"ModInt({self.x})"

    def __eq__(self, other):
        if isinstance(other, ModInt):
            return self.x == other.x
        return self.x == (other % MOD)

    def __ne__(self, other):
        return not self.__eq__(other)

    def __add__(self, other):
        if isinstance(other, ModInt):
            return ModInt(self.x + other.x)
        return ModInt(self.x + other)

    def __radd__(self, other):
        return self.__add__(other)

    def __sub__(self, other):
        if isinstance(other, ModInt):
            return ModInt(self.x - other.x)
        return ModInt(self.x - other)

    def __rsub__(self, other):
        if isinstance(other, ModInt):
            return ModInt(other.x - self.x)
        return ModInt(other - self.x)

    def __mul__(self, other):
        if isinstance(other, ModInt):
            return ModInt(self.x * other.x)
        return ModInt(self.x * other)

    def __rmul__(self, other):
        return self.__mul__(other)

    def __truediv__(self, other):
        if isinstance(other, ModInt):
            return self * other.inv()
        return self * ModInt(other).inv()

    def __rtruediv__(self, other):
        if isinstance(other, ModInt):
            return other * self.inv()
        return ModInt(other) * self.inv()

    def __pow__(self, exp):
        """Return self**exp; a negative exponent raises the inverse instead."""
        if exp < 0:
            # Go through inv() so that 0 ** -k stays 0 rather than raising.
            return ModInt(pow(self.inv().x, -exp, MOD))
        return ModInt(pow(self.x, exp, MOD))

    def inv(self):
        """Multiplicative inverse via Fermat's little theorem (0 maps to 0)."""
        return ModInt(pow(self.x, MOD - 2, MOD))

    def __neg__(self):
        return ModInt(-self.x)

    def __hash__(self):
        return hash(self.x)

    def __lt__(self, other):
        if isinstance(other, ModInt):
            return self.x < other.x
        return self.x < other

    def __le__(self, other):
        if isinstance(other, ModInt):
            return self.x <= other.x
        return self.x <= other

    def __gt__(self, other):
        return not self.__le__(other)

    def __ge__(self, other):
        return not self.__lt__(other)


# ---------- Combination class ----------
class Combination:
    """Lazily grown tables of factorials, inverse factorials and inverses mod MOD."""

    def __init__(self):
        self.fact = [ModInt(1)]
        self.ifact = [ModInt(1)]
        self.inv_list = [ModInt(0)]
        self.N = 0  # tables are valid for indices [0, N)

    def build(self, n):
        """Ensure fact/ifact/inv_list are valid for every index in [0, n)."""
        if n <= self.N:
            return
        start = max(self.N, 1)  # index 0 is always valid (0! == 1)
        grow = n - len(self.fact)
        if grow > 0:
            self.fact += [ModInt(0)] * grow
            self.ifact += [ModInt(0)] * grow
            self.inv_list += [ModInt(0)] * grow
        for i in range(start, n):
            self.fact[i] = self.fact[i - 1] * i
        self.ifact[n - 1] = self.fact[n - 1].inv()
        for i in range(n - 1, start, -1):
            self.ifact[i - 1] = self.ifact[i] * i
        # 1/i == (i-1)! / i!; filled for every new index, including i == 1
        # and i == start when the tables are extended.
        for i in range(start, n):
            self.inv_list[i] = self.ifact[i] * self.fact[i - 1]
        self.N = n

    def fac(self, k):
        if k >= self.N:
            self.build(k + 1)
        return self.fact[k]

    def ifac(self, k):
        if k >= self.N:
            self.build(k + 1)
        return self.ifact[k]

    def inv(self, k):
        if k >= self.N:
            self.build(k + 1)
        return self.inv_list[k]

    def comb(self, a, b):
        if a < b or b < 0:
            return ModInt(0)
        return self.fac(a) * self.ifac(a - b) * self.ifac(b)

    def perm(self, a, b):
        if a < b or b < 0:
            return ModInt(0)
        return self.fac(a) * self.ifac(a - b)


# ---------- FFT Info ----------
class FFTInfo:
    """Roots of unity and twiddle increments for the NTT over MOD.

    root[i] is a primitive 2^i-th root of unity (iroot[i] its inverse);
    rate2/irate2 and rate3/irate3 are the per-block twiddle multipliers used
    by the radix-2 and radix-4 passes of butterfly / butterfly_inv.
    """

    def __init__(self, g=PRIMITIVE_ROOT):
        # rank2 = number of trailing zero bits of MOD - 1 (23 for 998244353 =
        # 119 * 2^23 + 1).  (MOD - 1).bit_length() - 1 would be floor(log2),
        # which does not give 2-power roots of unity.
        self.rank2 = ((MOD - 1) & -(MOD - 1)).bit_length() - 1
        self.root = [ModInt(0)] * (self.rank2 + 1)
        self.iroot = [ModInt(0)] * (self.rank2 + 1)
        self.root[self.rank2] = ModInt(g) ** ((MOD - 1) >> self.rank2)
        self.iroot[self.rank2] = self.root[self.rank2].inv()
        for i in range(self.rank2 - 1, -1, -1):
            self.root[i] = self.root[i + 1] * self.root[i + 1]
            self.iroot[i] = self.iroot[i + 1] * self.iroot[i + 1]

        max_len = max(0, self.rank2 - 2 + 1)
        self.rate2 = [ModInt(0)] * max_len
        self.irate2 = [ModInt(0)] * max_len
        prod = ModInt(1)
        iprod = ModInt(1)
        for i in range(max_len):
            self.rate2[i] = self.root[i + 2] * prod
            self.irate2[i] = self.iroot[i + 2] * iprod
            prod *= self.iroot[i + 2]
            iprod *= self.root[i + 2]

        max_len3 = max(0, self.rank2 - 3 + 1)
        self.rate3 = [ModInt(0)] * max_len3
        self.irate3 = [ModInt(0)] * max_len3
        prod = ModInt(1)
        iprod = ModInt(1)
        for i in range(max_len3):
            self.rate3[i] = self.root[i + 3] * prod
            self.irate3[i] = self.iroot[i + 3] * iprod
            prod *= self.iroot[i + 3]
            iprod *= self.root[i + 3]


_FFT_INFO: Optional[FFTInfo] = None


def _fft_info() -> FFTInfo:
    """Return the shared FFTInfo, building it on first use (it is immutable)."""
    global _FFT_INFO
    if _FFT_INFO is None:
        _FFT_INFO = FFTInfo()
    return _FFT_INFO


def _trailing_ones(s: int) -> int:
    """Number of trailing 1 bits of a non-negative int s (== ctz(s + 1))."""
    t = s + 1
    return (t & -t).bit_length() - 1


def _log2_exact(n: int) -> int:
    """Return log2(n) for a power of two n; raise ValueError otherwise."""
    if n & (n - 1):
        raise ValueError("NTT length must be a power of two")
    return n.bit_length() - 1


# ---------- FFT Functions ----------
def butterfly(a: List[ModInt]) -> None:
    """In-place forward NTT of a (length must be a power of two).

    The output order is the one butterfly_inv expects:
    butterfly_inv(butterfly(a)) leaves len(a) * a.
    """
    n = len(a)
    if n <= 1:
        return
    h = _log2_exact(n)
    info = _fft_info()
    length = 0
    while length < h:
        if h - length == 1:
            # radix-2 pass
            p = 1 << (h - length - 1)
            rot = ModInt(1)
            for s in range(1 << length):
                offset = s << (h - length)
                for i in range(offset, offset + p):
                    lo = a[i]
                    hi = a[i + p] * rot
                    a[i] = lo + hi
                    a[i + p] = lo - hi
                if s + 1 != (1 << length):
                    rot *= info.rate2[_trailing_ones(s)]
            length += 1
        else:
            # radix-4 pass
            p = 1 << (h - length - 2)
            rot = ModInt(1)
            imag = info.root[2]
            for s in range(1 << length):
                rot2 = rot * rot
                rot3 = rot2 * rot
                offset = s << (h - length)
                for i in range(offset, offset + p):
                    a0 = a[i]
                    a1 = a[i + p] * rot
                    a2 = a[i + 2 * p] * rot2
                    a3 = a[i + 3 * p] * rot3
                    a1na3imag = (a1 - a3) * imag
                    a[i] = a0 + a2 + a1 + a3
                    a[i + p] = a0 + a2 - a1 - a3
                    a[i + 2 * p] = a0 - a2 + a1na3imag
                    a[i + 3 * p] = a0 - a2 - a1na3imag
                if s + 1 != (1 << length):
                    rot *= info.rate3[_trailing_ones(s)]
            length += 2


def butterfly_inv(a: List[ModInt]) -> None:
    """In-place inverse of butterfly, without the final 1/len(a) scaling."""
    n = len(a)
    if n <= 1:
        return
    h = _log2_exact(n)
    info = _fft_info()
    length = h
    while length:
        if length == 1:
            # radix-2 pass
            p = 1 << (h - length)
            irot = ModInt(1)
            for s in range(1 << (length - 1)):
                offset = s << (h - length + 1)
                for i in range(offset, offset + p):
                    lo = a[i]
                    hi = a[i + p]
                    a[i] = lo + hi
                    a[i + p] = (lo - hi) * irot
                if s + 1 != (1 << (length - 1)):
                    irot *= info.irate2[_trailing_ones(s)]
            length -= 1
        else:
            # radix-4 pass
            p = 1 << (h - length)
            irot = ModInt(1)
            iimag = info.iroot[2]
            for s in range(1 << (length - 2)):
                irot2 = irot * irot
                irot3 = irot2 * irot
                offset = s << (h - length + 2)
                for i in range(offset, offset + p):
                    a0 = a[i]
                    a1 = a[i + p]
                    a2 = a[i + 2 * p]
                    a3 = a[i + 3 * p]
                    a2na3iimag = (a2 - a3) * iimag
                    a[i] = a0 + a1 + a2 + a3
                    a[i + p] = (a0 - a1 + a2na3iimag) * irot
                    a[i + 2 * p] = (a0 + a1 - a2 - a3) * irot2
                    a[i + 3 * p] = (a0 - a1 - a2na3iimag) * irot3
                if s + 1 != (1 << (length - 2)):
                    irot *= info.irate3[_trailing_ones(s)]
            length -= 2


def convolution_fft(a: List[ModInt], b: List[ModInt]) -> List[ModInt]:
    """Linear convolution of non-empty a and b via NTT."""
    n, m = len(a), len(b)
    z = 1 << (n + m - 2).bit_length()  # smallest power of two >= n + m - 1
    a = a + [ModInt(0)] * (z - n)
    b = b + [ModInt(0)] * (z - m)
    butterfly(a)
    butterfly(b)
    for i in range(z):
        a[i] *= b[i]
    butterfly_inv(a)
    iz = ModInt(z).inv()
    return [a[i] * iz for i in range(n + m - 1)]


def convolution_mod(a: List[ModInt], b: List[ModInt]) -> List[ModInt]:
    """Linear convolution mod MOD; naive when the shorter input is <= 60 long."""
    if not a or not b:
        return []
    n, m = len(a), len(b)
    if min(n, m) <= 60:
        res = [ModInt(0)] * (n + m - 1)
        for i, ai in enumerate(a):
            for j, bj in enumerate(b):
                res[i + j] += ai * bj
        return res
    return convolution_fft(a, b)


# ---------- Counting Path on Grid ----------
class CountingPathOnGrid:
    """Propagate boundary values through an H x W grid of up/right lattice paths."""

    # F[k] = butterfly(0! + 1! x + ... + (2^k - 1)! x^(2^k - 1)), built lazily.
    F = [[] for _ in range(25)]

    @classmethod
    def make_f(cls, k: int) -> None:
        """Build and cache the transformed factorial table F[k]."""
        comb = Combination()
        comb.build(1 << k)
        f = [comb.fac(i) for i in range(1 << k)]
        butterfly(f)
        cls.F[k] = f

    @classmethod
    def solve(cls, H: int, W: int, w: List[ModInt]) -> List[ModInt]:
        """Map boundary inputs of an H x W grid to its top/right boundary.

        Rows are 0 (top) .. H-1 (bottom), columns 0 (left) .. W-1 (right).
        Inputs: w[r] enters cell (r, 0) for r in [0, H); w[H-1+c] enters cell
        (H-1, c) for c in [1, W).  Every cell sums its input, the cell below
        and the cell to its left.  Returns res with res[c] = cell (0, c) for
        c in [0, W) and res[W-1+r] = cell (r, W-1) for r in [0, H).
        """
        if not H or not W:
            return []
        assert len(w) == H + W - 1
        comb = Combination()
        comb.build(H + W)
        logHW = (H + W - 1).bit_length()
        size = 1 << logHW
        iz = ModInt(size).inv()
        if not cls.F[logHW]:
            cls.make_f(logHW)
        F = cls.F[logHW]

        # The bottom-row passes are skipped when all bottom inputs are zero.
        use_sita = any(w[H - 1 + i] != 0 for i in range(1, W))
        res = [ModInt(0)] * (H + W - 1)

        # Left -> Up: middle product with factorials, no cyclic wrap-around
        # because only indices >= H - 1 are read.
        A = [ModInt(0)] * size
        for i in range(H):
            A[i] = w[H - 1 - i] * comb.ifac(H - 1 - i)
        butterfly(A)
        for i in range(size):
            A[i] *= F[i]
        butterfly_inv(A)
        for i in range(W):
            res[i] += A[H - 1 + i] * iz * comb.ifac(i)

        # Bottom -> Right
        if use_sita:
            A = [ModInt(0)] * size
            for i in range(1, W):
                A[i] = w[H - 1 + i] * comb.ifac(W - 1 - i)
            butterfly(A)
            for i in range(size):
                A[i] *= F[i]
            butterfly_inv(A)
            for i in range(H):
                res[H + W - 2 - i] += A[W - 1 + i] * iz * comb.ifac(i)

        # Left -> Right (the top-right corner is already covered above)
        B = w[:H]
        C = [comb.comb(H + W - 2 - i, W - 1) for i in range(H)]
        conv = convolution_mod(B, C)
        for i in range(1, H):
            res[W - 1 + i] += conv[H - 1 + i]

        # Bottom -> Up
        if use_sita:
            B = [ModInt(0)] * W
            for i in range(W - 1):
                B[i] = w[H + W - 2 - i]
            C = [comb.comb(H + W - 2 - i, H - 1) for i in range(W)]
            conv = convolution_mod(B, C)
            for i in range(1, W):
                res[W - 1 - i] += conv[W - 1 + i]

        return res


# ---------- Main Algorithm ----------
def number_of_increasing_sequence(A: List[int], B: List[int]) -> List[ModInt]:
    """Count non-decreasing sequences x with A[i] <= x[i] < B[i].

    Returns a list L of length B[-1]; for v in [A[-1], B[-1]), L[v] is the
    number (mod MOD) of such sequences with x[-1] == v.  Entries outside that
    range are scratch values.  NOTE(review): assumes A and B are
    non-decreasing with A[i] < B[i] (as main supplies) — confirm for other
    callers.
    """
    N = len(A)
    if N == 0:
        return []
    D = [ModInt(0)] * (N + 1)  # flow crossing between consecutive indices
    L = [ModInt(0)] * B[-1]  # per-value counts for the current index
    L[A[0]] = ModInt(1)
    C = [0] * N  # current staircase boundary (reflected A, or B)

    def dfs_inc(l: int, r: int, d: int) -> None:
        # Upper staircase bounded by C (= B) over indices [l, r), values >= d.
        if d == C[r - 1]:
            return
        if r - l == 1:
            L[d] += D[l]
            for i in range(d + 1, C[l]):
                L[i] += L[i - 1]
            D[l] = L[C[l] - 1]
            return
        m = (l + r) // 2
        if l < m:
            dfs_inc(l, m, d)
        y = C[m]
        Hi = y - d
        Wi = r - m
        if Hi and Wi:
            w = [ModInt(0)] * (Hi + Wi - 1)
            for i in range(Hi):
                w[i] = L[y - 1 - i]
            for i in range(Wi):
                w[Hi - 1 + i] += D[m + i]
            w = CountingPathOnGrid.solve(Hi, Wi, w)
            for i in range(Wi):
                D[m + i] = w[i]
            for i in range(Hi):
                L[y - 1 - i] = w[Wi - 1 + i]
        if m < r:
            dfs_inc(m, r, y)

    def dfs_dec(l: int, r: int, d: int) -> None:
        # Lower staircase (values reflected) over indices [l, r), values >= d.
        if d == C[l]:
            return
        if r - l == 1:
            L[C[l] - 1] += D[l]
            for i in range(C[l] - 2, d - 1, -1):
                L[i] += L[i + 1]
            D[l] = L[d]
            for i in range(d, C[l]):
                D[l + 1] += L[i]
            return
        m = (l + r) // 2
        y = C[m - 1]
        if l < m:
            dfs_dec(l, m, y)
        Hi = y - d
        Wi = m - l
        if Hi and Wi:
            w = [ModInt(0)] * (Hi + Wi - 1)
            for i in range(Hi):
                w[i] = L[d + i]
            for i in range(Wi):
                w[Hi - 1 + i] += D[l + i]
            w = CountingPathOnGrid.solve(Hi, Wi, w)
            for i in range(Hi):
                L[d + i] = w[Wi - 1 + i]
            for i in range(Wi):
                D[l + i] = w[i]
        for i in range(C[m], y):
            D[m] += L[i]
        if m < r:
            dfs_dec(m, r, d)

    low_sum = ModInt(0)
    l = 0
    while l < N:
        # Block [l, r): consecutive indices whose A stays below B[l].
        r = l
        while r < N and (r == l or A[r] < B[l]):
            r += 1
        D[l] = low_sum
        low_sum = ModInt(0)

        # Lower part: reflect values in [A[l], B[l]) so A becomes a
        # decreasing staircase, solve, then reflect back.
        for i in range(l, r):
            C[i] = B[l] - A[i] + A[l]
        L[A[l]:B[l]] = L[A[l]:B[l]][::-1]
        dfs_dec(l, r, A[l])
        for i in range(A[l], C[r - 1]):
            low_sum += L[i]
        L[A[l]:B[l]] = L[A[l]:B[l]][::-1]

        # Upper part: staircase bounded by B, starting at value B[l].
        for i in range(l, r):
            C[i] = B[i]
        dfs_inc(l, r, B[l])
        for i in range(B[l], A[r] if r < N else 0):
            low_sum += L[i]
        l = r
    return L


# ---------- Main Function ----------
def main():
    """Read S from stdin and print the sequence count modulo MOD.

    For the k-th 'A' in S (at 0-based index i), B[k] = i + 1 - k, i.e. one
    more than the number of non-'A' characters before it; A is all zeros.
    """
    S = sys.stdin.readline().strip()
    B = []
    cnt = 0
    for i, ch in enumerate(S):
        if ch == 'A':
            B.append(i + 1 - cnt)
            cnt += 1
    N = len(B)
    A = [0] * N
    if N == 0:
        print(1)
        return
    # Running maximum keeps A non-decreasing (a no-op for the all-zero A).
    for i in range(1, N):
        A[i] = max(A[i], A[i - 1])
    ans = number_of_increasing_sequence(A, B)
    s = ModInt(0)
    for i in range(A[-1], B[-1]):
        s += ans[i]
    print(s)


if __name__ == "__main__":
    main()