結果

問題 No.3370 AB → BA
コンテスト
ユーザー Aralov Otabek
提出日時 2026-01-08 22:59:38
言語 PyPy3 (7.3.17)
結果
TLE  
実行時間 -
コード長 15,982 bytes
記録
記録タグの例:
初AC ショートコード 純ショートコード 純主流ショートコード 最速実行時間
コンパイル時間 285 ms
コンパイル使用メモリ 82,296 KB
実行使用メモリ 276,680 KB
最終ジャッジ日時 2026-01-08 22:59:48
合計ジャッジ時間 4,591 ms
ジャッジサーバーID (参考情報) judge4 / judge2
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 1
other TLE * 3 -- * 17
権限があれば一括ダウンロードができます

ソースコード

diff #
raw source code

import sys
import math
from typing import List, Tuple, Optional
import sys

MOD = 998244353  # prime modulus; MOD - 1 = 119 * 2**23, so NTT lengths up to 2**23 exist
PRIMITIVE_ROOT = 3  # primitive root modulo MOD; FFTInfo derives its roots of unity from it

# ---------- ModInt class ----------
class ModInt:
    """An integer modulo ``MOD`` with arithmetic and comparison operators.

    The residue is stored in ``self.x`` and is always normalised to ``[0, MOD)``.
    Binary operators accept another ``ModInt`` or a plain ``int`` on either side.
    Division and ``inv`` rely on Fermat's little theorem (``MOD`` is prime);
    ``ModInt(0).inv()`` yields 0 instead of raising.
    """

    # Every arithmetic result allocates a new ModInt, so dropping the
    # per-instance __dict__ saves memory and speeds up attribute access.
    __slots__ = ("x",)

    def __init__(self, x=0):
        if isinstance(x, ModInt):
            self.x = x.x
        else:
            # Python's % with a positive right operand is already non-negative,
            # so negative inputs need no extra fix-up.
            self.x = x % MOD

    def __int__(self):
        return self.x

    def __str__(self):
        return str(self.x)

    def __repr__(self):
        return f"ModInt({self.x})"

    def __eq__(self, other):
        if isinstance(other, ModInt):
            return self.x == other.x
        return self.x == (other % MOD)

    def __ne__(self, other):
        return not self.__eq__(other)

    def __add__(self, other):
        if isinstance(other, ModInt):
            return ModInt(self.x + other.x)
        return ModInt(self.x + other)

    def __radd__(self, other):
        return self.__add__(other)

    def __sub__(self, other):
        if isinstance(other, ModInt):
            return ModInt(self.x - other.x)
        return ModInt(self.x - other)

    def __rsub__(self, other):
        if isinstance(other, ModInt):
            return ModInt(other.x - self.x)
        return ModInt(other - self.x)

    def __mul__(self, other):
        if isinstance(other, ModInt):
            return ModInt(self.x * other.x)
        return ModInt(self.x * other)

    def __rmul__(self, other):
        return self.__mul__(other)

    def __truediv__(self, other):
        if isinstance(other, ModInt):
            return self * other.inv()
        return self * ModInt(other).inv()

    def __rtruediv__(self, other):
        if isinstance(other, ModInt):
            return other * self.inv()
        return ModInt(other) * self.inv()

    def __pow__(self, exp):
        """Return ``self ** exp``; a negative ``exp`` inverts first."""
        if exp < 0:
            return self.inv() ** (-exp)
        # Built-in three-argument pow does square-and-multiply on plain ints,
        # avoiding one ModInt allocation per step.
        return ModInt(pow(self.x, exp, MOD))

    def inv(self):
        """Return the multiplicative inverse via Fermat: ``x ** (MOD - 2)``."""
        return ModInt(pow(self.x, MOD - 2, MOD))

    def __neg__(self):
        return ModInt(-self.x)

    def __hash__(self):
        return hash(self.x)

    def __lt__(self, other):
        if isinstance(other, ModInt):
            return self.x < other.x
        return self.x < other

    def __le__(self, other):
        if isinstance(other, ModInt):
            return self.x <= other.x
        return self.x <= other

    def __gt__(self, other):
        return not self.__le__(other)

    def __ge__(self, other):
        return not self.__lt__(other)

# ---------- Combination class ----------
class Combination:
    """Factorial, inverse-factorial and modular-inverse tables, grown on demand.

    After ``build(n)`` the entries ``0 .. n-1`` of ``fact``, ``ifact`` and
    ``inv_list`` are valid and ``self.N == n``. ``inv_list[0]`` is a
    placeholder 0 (0 has no inverse).
    """

    def __init__(self):
        self.fact = [ModInt(1)]
        self.ifact = [ModInt(1)]
        self.inv_list = [ModInt(0)]
        self.N = 0

    def build(self, n):
        """Extend the tables so that every index ``< n`` is valid."""
        if n <= self.N:
            return
        # Index 0 is seeded by __init__, so the first entry to compute is max(1, N).
        lo = max(1, self.N)
        fact, ifact, inv_list = self.fact, self.ifact, self.inv_list
        # Drop any slot past the valid prefix, then pad to exactly n entries
        # (in place, instead of copying each table on every rebuild).
        for table in (fact, ifact, inv_list):
            del table[lo:]
            table.extend([ModInt(0)] * (n - lo))

        for i in range(lo, n):
            fact[i] = fact[i - 1] * i

        # One modular inversion, then 1/(i-1)! = i * (1/i!) walking downwards.
        ifact[n - 1] = fact[n - 1].inv()
        for i in range(n - 1, lo, -1):
            ifact[i - 1] = ifact[i] * i

        # 1/i = (i-1)! * (1/i!) for every newly valid index, including 1 and the
        # previous boundary N, which must not be left at the 0 padding.
        for i in range(lo, n):
            inv_list[i] = ifact[i] * fact[i - 1]

        self.N = n

    def _ensure(self, k):
        """Make index ``k`` valid, at least doubling the tables to amortise rebuilds."""
        if k >= self.N:
            self.build(max(k + 1, 2 * self.N))

    def fac(self, k):
        """Return ``k!``."""
        self._ensure(k)
        return self.fact[k]

    def ifac(self, k):
        """Return ``1 / k!``."""
        self._ensure(k)
        return self.ifact[k]

    def inv(self, k):
        """Return ``1 / k`` (0 for ``k == 0``)."""
        self._ensure(k)
        return self.inv_list[k]

    def comb(self, a, b):
        """Return ``C(a, b)``, or 0 when ``b < 0`` or ``a < b``."""
        if a < b or b < 0:
            return ModInt(0)
        return self.fac(a) * self.ifac(a - b) * self.ifac(b)

    def perm(self, a, b):
        """Return ``a! / (a - b)!``, or 0 when ``b < 0`` or ``a < b``."""
        if a < b or b < 0:
            return ModInt(0)
        return self.fac(a) * self.ifac(a - b)

# ---------- FFT Info ----------
class FFTInfo:
    """Roots of unity and twiddle "rates" for the NTT butterflies (AtCoder Library layout).

    Attributes (lists of ModInt):
      root[i] / iroot[i] -- a primitive 2**i-th root of unity and its inverse for
                            0 <= i <= rank2, assuming ``g`` is a primitive root mod MOD;
      rate2 / irate2     -- rotation steps used by the radix-2 passes;
      rate3 / irate3     -- rotation steps used by the radix-4 passes.
    """

    def __init__(self, g=PRIMITIVE_ROOT):
        # rank2 must be the number of trailing zero bits of MOD - 1, i.e. the
        # largest e with 2**e | MOD - 1 (23 for 998244353 = 119 * 2**23 + 1).
        # floor(log2(MOD - 1)) would be wrong: g ** ((MOD - 1) >> rank2) only has
        # order 2**rank2 when the shift divides MOD - 1 exactly.
        m1 = MOD - 1
        self.rank2 = (m1 & -m1).bit_length() - 1
        self.root = [ModInt(0)] * (self.rank2 + 2)
        self.iroot = [ModInt(0)] * (self.rank2 + 2)

        self.root[self.rank2] = ModInt(g) ** ((MOD - 1) >> self.rank2)
        self.iroot[self.rank2] = self.root[self.rank2].inv()

        # Squaring a primitive 2**(i+1)-th root gives a primitive 2**i-th root.
        for i in range(self.rank2 - 1, -1, -1):
            self.root[i] = self.root[i + 1] * self.root[i + 1]
            self.iroot[i] = self.iroot[i + 1] * self.iroot[i + 1]

        max_len = max(0, self.rank2 - 2 + 1)
        self.rate2 = [ModInt(0)] * max_len
        self.irate2 = [ModInt(0)] * max_len

        max_len3 = max(0, self.rank2 - 3 + 1)
        self.rate3 = [ModInt(0)] * max_len3
        self.irate3 = [ModInt(0)] * max_len3

        prod = ModInt(1)
        iprod = ModInt(1)
        for i in range(max_len):
            self.rate2[i] = self.root[i + 2] * prod
            self.irate2[i] = self.iroot[i + 2] * iprod
            prod *= self.iroot[i + 2]
            iprod *= self.root[i + 2]

        prod = ModInt(1)
        iprod = ModInt(1)
        for i in range(max_len3):
            self.rate3[i] = self.root[i + 3] * prod
            self.irate3[i] = self.iroot[i + 3] * iprod
            prod *= self.iroot[i + 3]
            iprod *= self.root[i + 3]

# ---------- FFT Functions ----------
# Module-wide FFTInfo: its tables depend only on MOD and PRIMITIVE_ROOT, so it
# is built once instead of on every transform.
_FFT_INFO = None


def _get_fft_info():
    """Return the lazily created, shared FFTInfo instance."""
    global _FFT_INFO
    if _FFT_INFO is None:
        _FFT_INFO = FFTInfo()
    return _FFT_INFO


def _trailing_ones(s: int) -> int:
    """Return the number of trailing 1-bits of ``s >= 0`` (== countr_zero(~s) unsigned)."""
    return ((~s) & (s + 1)).bit_length() - 1


def _log2_len(n: int) -> int:
    """Return the number of trailing zero bits of ``n`` (log2(n) for a power of two)."""
    return (n & -n).bit_length() - 1


def butterfly(a: List[ModInt]):
    """Forward NTT of ``a`` in place; ``len(a)`` must be a power of two.

    Mirrors AtCoder Library's ``butterfly``: radix-4 passes plus one radix-2
    pass when log2(len) is odd, with no bit-reversal permutation applied.
    ``butterfly_inv`` undoes it up to a factor of ``len(a)``. The arithmetic is
    done on plain ints and the results are written back as ModInt values.
    """
    h = _log2_len(len(a))
    if h <= 0:
        return  # length 0 or 1: the transform is the identity
    mod = MOD
    info = _get_fft_info()
    rate2 = [int(r) for r in info.rate2]
    rate3 = [int(r) for r in info.rate3]
    imag = int(info.root[2])
    v = [int(x) for x in a]

    length = 0
    while length < h:
        if h - length == 1:
            # radix-2 pass
            p = 1 << (h - length - 1)
            rot = 1
            for s in range(1 << length):
                offset = s << (h - length)
                for i in range(offset, offset + p):
                    lo = v[i]
                    hi = v[i + p] * rot % mod
                    v[i] = (lo + hi) % mod
                    v[i + p] = (lo - hi) % mod
                if s + 1 != (1 << length):
                    rot = rot * rate2[_trailing_ones(s)] % mod
            length += 1
        else:
            # radix-4 pass
            p = 1 << (h - length - 2)
            rot = 1
            for s in range(1 << length):
                rot2 = rot * rot % mod
                rot3 = rot2 * rot % mod
                offset = s << (h - length)
                for i in range(offset, offset + p):
                    a0 = v[i]
                    a1 = v[i + p] * rot % mod
                    a2 = v[i + 2 * p] * rot2 % mod
                    a3 = v[i + 3 * p] * rot3 % mod
                    a1na3imag = (a1 - a3) * imag % mod
                    v[i] = (a0 + a2 + a1 + a3) % mod
                    v[i + p] = (a0 + a2 - a1 - a3) % mod
                    v[i + 2 * p] = (a0 - a2 + a1na3imag) % mod
                    v[i + 3 * p] = (a0 - a2 - a1na3imag) % mod
                if s + 1 != (1 << length):
                    rot = rot * rate3[_trailing_ones(s)] % mod
            length += 2

    a[:] = [ModInt(x) for x in v]


def butterfly_inv(a: List[ModInt]):
    """Inverse of ``butterfly`` in place, WITHOUT the final ``1 / len(a)`` scaling.

    Callers (e.g. ``convolution_fft``) multiply by ``ModInt(len(a)).inv()``
    themselves. ``len(a)`` must be a power of two.
    """
    h = _log2_len(len(a))
    if h <= 0:
        return  # length 0 or 1: the transform is the identity
    mod = MOD
    info = _get_fft_info()
    irate2 = [int(r) for r in info.irate2]
    irate3 = [int(r) for r in info.irate3]
    iimag = int(info.iroot[2])
    v = [int(x) for x in a]

    length = h
    while length:
        if length == 1:
            # radix-2 pass
            p = 1 << (h - length)
            irot = 1
            for s in range(1 << (length - 1)):
                offset = s << (h - length + 1)
                for i in range(offset, offset + p):
                    lo = v[i]
                    hi = v[i + p]
                    v[i] = (lo + hi) % mod
                    v[i + p] = (lo - hi) * irot % mod
                if s + 1 != (1 << (length - 1)):
                    irot = irot * irate2[_trailing_ones(s)] % mod
            length -= 1
        else:
            # radix-4 pass
            p = 1 << (h - length)
            irot = 1
            for s in range(1 << (length - 2)):
                irot2 = irot * irot % mod
                irot3 = irot2 * irot % mod
                offset = s << (h - length + 2)
                for i in range(offset, offset + p):
                    a0 = v[i]
                    a1 = v[i + p]
                    a2 = v[i + 2 * p]
                    a3 = v[i + 3 * p]
                    a2na3iimag = (a2 - a3) * iimag % mod
                    v[i] = (a0 + a1 + a2 + a3) % mod
                    v[i + p] = (a0 - a1 + a2na3iimag) * irot % mod
                    v[i + 2 * p] = (a0 + a1 - a2 - a3) * irot2 % mod
                    v[i + 3 * p] = (a0 - a1 - a2na3iimag) * irot3 % mod
                if s + 1 != (1 << (length - 2)):
                    irot = irot * irate3[_trailing_ones(s)] % mod
            length -= 2

    a[:] = [ModInt(x) for x in v]

def convolution_fft(a: List[ModInt], b: List[ModInt]) -> List[ModInt]:
    """Linear convolution of ``a`` and ``b`` through a power-of-two NTT.

    The inputs are copied before zero-padding, so the caller's lists are not
    modified. The result has ``len(a) + len(b) - 1`` entries.
    """
    len_a = len(a)
    len_b = len(b)
    out_len = len_a + len_b - 1
    # Smallest power of two that can hold the whole product without wrapping.
    size = 1 << (out_len - 1).bit_length()

    fa = list(a)
    fa += [ModInt(0)] * (size - len_a)
    fb = list(b)
    fb += [ModInt(0)] * (size - len_b)

    butterfly(fa)
    butterfly(fb)
    for k in range(size):
        fa[k] = fa[k] * fb[k]
    butterfly_inv(fa)

    # butterfly_inv leaves a factor of `size` in every entry; divide it out.
    scale = ModInt(size).inv()
    return [fa[k] * scale for k in range(out_len)]

def convolution_mod(a: List[ModInt], b: List[ModInt]) -> List[ModInt]:
    """Return the linear convolution of ``a`` and ``b`` modulo MOD.

    Empty input gives ``[]``. When the shorter operand has at most 60 terms a
    schoolbook product is used; otherwise the work is delegated to
    ``convolution_fft``.
    """
    if not a or not b:
        return []

    len_a, len_b = len(a), len(b)
    if min(len_a, len_b) > 60:
        return convolution_fft(a, b)

    # Schoolbook product; the outer loop runs over the longer operand.
    out = [ModInt(0)] * (len_a + len_b - 1)
    if len_a < len_b:
        for j, bj in enumerate(b):
            for i, ai in enumerate(a):
                out[i + j] += ai * bj
    else:
        for i, ai in enumerate(a):
            for j, bj in enumerate(b):
                out[i + j] += ai * bj
    return out

# ---------- Counting Path on Grid ----------
class CountingPathOnGrid:
    """Transfer weights from the left/bottom edges of an H x W grid to its top/right edges.

    ``solve`` sums four contributions, labelled below as Left -> Up,
    Bottom -> Right, Left -> Right and Bottom -> Up. Each one is a convolution
    of the input weights with a binomial-coefficient kernel, i.e. a weighted
    lattice-path count.
    NOTE(review): the exact path model (step directions, whether the corner
    cells are inclusive) is inferred from those labels and kernels -- confirm
    against the reference implementation this was ported from.
    """
    # Class-level cache shared by every call: F[k] is the forward butterfly
    # (NTT) of the factorial polynomial below, filled lazily by make_f.
    # Only k < 25 is supported.
    F = [[] for _ in range(25)]  # F[k] = ntt(0! + 1!x + ... + (2^k-1)!x^(2^k-1))
    
    @classmethod
    def make_f(cls, k: int):
        """Fill ``cls.F[k]`` with the butterfly transform of ``0!, 1!, ..., (2^k - 1)!``."""
        comb = Combination()
        comb.build(1 << k)
        cls.F[k] = [comb.fac(i) for i in range(1 << k)]
        butterfly(cls.F[k])
    
    @classmethod
    def solve(cls, H: int, W: int, w: List[ModInt]) -> List[ModInt]:
        """Return the output edge weights ``res`` (length H + W - 1) for input ``w``.

        Index layout, as read and written by the code below:
          * ``w[0 .. H-1]``        -- left-edge inputs (``w[H-1]`` is the corner
                                     shared with the bottom edge);
          * ``w[H-1 .. H+W-2]``    -- bottom-edge inputs;
          * ``res[0 .. W-1]``      -- top-edge outputs (Left -> Up, Bottom -> Up);
          * ``res[W-1 .. H+W-2]``  -- right-edge outputs (Bottom -> Right,
                                     Left -> Right).
        Returns ``[]`` when H or W is 0 (after the length assertion).
        """
        assert len(w) == H + W - 1
        
        if not H or not W:
            return []
        
        # Fresh factorial tables per call; indices up to H + W - 1 are needed.
        comb = Combination()
        comb.build(H + W)
        
        # Transform length 2**logHW > H + W - 2, so the cyclic products below do
        # not wrap around into any index that is read back.
        logHW = (H + W - 1).bit_length()
        A = [ModInt(0)] * (1 << logHW)
        # butterfly_inv does not divide by the length; iz compensates.
        iz = ModInt(1 << logHW).inv()
        
        # Check if we need to use bottom paths ("sita" = bottom): any bottom
        # input other than the shared corner w[H-1] is non-zero. If none is,
        # the two bottom-sourced passes are skipped.
        use_sita = any(w[H - 1 + i] != 0 for i in range(1, W))
        
        res = [ModInt(0)] * (H + W - 1)
        
        # Left -> Up: res[i] += sum_t w[t] * C(t + i, i) for i < W, computed as
        # (w[t] / t!) convolved with the factorial series, then times 1 / i!.
        for i in range(H):
            A[i] = w[H - 1 - i] * comb.ifac(H - 1 - i)
        
        butterfly(A)
        if not cls.F[logHW]:
            cls.make_f(logHW)
        
        for i in range(1 << logHW):
            A[i] *= cls.F[logHW][i]
        
        butterfly_inv(A)
        for i in range(W):
            j = H - 1 + i
            res[i] += A[j] * iz * comb.ifac(i)
        
        # Bottom -> Right: res[H+W-2-i] += sum_{j=1}^{W-1} w[H-1+j] * C(W-1-j+i, i)
        # for i < H (same factorial-series trick as above).
        if use_sita:
            A = [ModInt(0)] * (1 << logHW)
            for i in range(1, W):
                A[i] = w[H - 1 + i] * comb.ifac(W - 1 - i)
            
            butterfly(A)
            for i in range(1 << logHW):
                A[i] *= cls.F[logHW][i]
            
            butterfly_inv(A)
            for i in range(H):
                res[H + W - 2 - i] += A[W - 1 + i] * iz * comb.ifac(i)
        
        # Left -> Right: res[W-1+i] += sum_{j>=i} w[j] * C(W-1+j-i, W-1) for
        # 1 <= i < H (i = 0 is the top-right corner, already filled above).
        B = w[:H]
        C = [comb.comb(H + W - 2 - i, W - 1) for i in range(H)]
        conv = convolution_mod(B, C)
        for i in range(1, H):
            res[W - 1 + i] += conv[H - 1 + i]
        
        # Bottom -> Up: bottom inputs read right-to-left (the corner slot B[W-1]
        # stays 0), convolved with kernel C(H+W-2-i, H-1); written to
        # res[W-1-i] for 1 <= i < W.
        if use_sita:
            B = [ModInt(0)] * W
            for i in range(W - 1):
                B[i] = w[H + W - 2 - i]
            
            C = [comb.comb(H + W - 2 - i, H - 1) for i in range(W)]
            conv = convolution_mod(B, C)
            for i in range(1, W):
                res[W - 1 - i] += conv[W - 1 + i]
        
        return res

# ---------- Main Algorithm ----------
def number_of_increasing_sequence(A: List[int], B: List[int]) -> List[ModInt]:
    """Run a divide-and-conquer DP over index runs and return the final value array ``L``.

    ``L`` has length ``B[-1]`` and starts as the indicator of value ``A[0]``;
    the caller (``main``) sums ``L[A[-1]:B[-1]]`` to obtain its answer.
    NOTE(review): judging by the name, this presumably counts non-decreasing
    integer sequences x with A[i] <= x[i] < B[i]. Preconditions on A and B
    (e.g. both non-decreasing, A[i] < B[i]) are not checked here -- confirm
    against the reference implementation.

    Returns ``[]`` when ``A`` is empty.
    """
    N = len(A)
    if N == 0:
        return []
    
    # D: per-index transfer values; length N + 1 because dfs_dec's leaf writes D[l + 1].
    D = [ModInt(0)] * (N + 1)
    # L: DP values indexed by value 0 .. B[-1] - 1.
    L = [ModInt(0)] * B[-1]
    L[A[0]] = ModInt(1)
    # C: per-index bound used by the recursions; reassigned before each pass.
    C = [0] * N
    
    def dfs_inc(l: int, r: int, d: int):
        """Process indices [l, r) upward from value d, with per-index bounds C.

        Leaf: add D[l] into L[d], prefix-sum L over [d, C[l]), and store the
        last partial sum L[C[l] - 1] into D[l]. Internal node: recurse on
        [l, m), hand the block of L-values [d, C[m]) and D-values [m, r) to
        CountingPathOnGrid.solve, then recurse on [m, r) from value C[m].
        """
        if d == C[r - 1]:
            return
        
        if r - l == 1:
            L[d] += D[l]
            for i in range(d + 1, C[l]):
                L[i] += L[i - 1]
            D[l] = L[C[l] - 1]
            return
        
        m = (l + r) // 2
        if l < m:
            dfs_inc(l, m, d)
        
        y = C[m]
        Hi = y - d
        Wi = r - m
        
        if Hi and Wi:
            # Pack L[d..y-1] (top value first) followed by D[m..r-1]; D[m] is
            # added onto the shared corner slot w[Hi - 1].
            w = [ModInt(0)] * (Hi + Wi - 1)
            for i in range(Hi):
                w[i] = L[y - 1 - i]
            for i in range(Wi):
                w[Hi - 1 + i] += D[m + i]
            
            w = CountingPathOnGrid.solve(Hi, Wi, w)
            for i in range(Wi):
                D[m + i] = w[i]
            for i in range(Hi):
                L[y - 1 - i] = w[Wi - 1 + i]
        
        if m < r:
            dfs_inc(m, r, y)
    
    def dfs_dec(l: int, r: int, d: int):
        """Process indices [l, r) on the reversed value range starting at d.

        Leaf: add D[l] into L[C[l] - 1], suffix-sum L down to d, set D[l] to
        L[d] and add sum(L[d:C[l]]) into D[l + 1]. Internal node: recurse on
        [l, m) from value C[m - 1], run CountingPathOnGrid.solve on L[d:C[m-1]]
        and D[l:m], add L[C[m]:C[m-1]] into D[m], then recurse on [m, r) from d.
        """
        if d == C[l]:
            return
        
        if r - l == 1:
            L[C[l] - 1] += D[l]
            for i in range(C[l] - 2, d - 1, -1):
                L[i] += L[i + 1]
            D[l] = L[d]
            for i in range(d, C[l]):
                D[l + 1] += L[i]
            return
        
        m = (l + r) // 2
        y = C[m - 1]
        
        if l < m:
            dfs_dec(l, m, y)
        
        Hi = y - d
        Wi = m - l
        
        if Hi and Wi:
            w = [ModInt(0)] * (Hi + Wi - 1)
            for i in range(Hi):
                w[i] = L[d + i]
            for i in range(Wi):
                w[Hi - 1 + i] += D[l + i]
            
            w = CountingPathOnGrid.solve(Hi, Wi, w)
            for i in range(Hi):
                L[d + i] = w[Wi - 1 + i]
            for i in range(Wi):
                D[l + i] = w[i]
        
        for i in range(C[m], y):
            D[m] += L[i]
        
        if m < r:
            dfs_dec(m, r, d)
    
    # Mass carried from one run to the next; seeded into D[l] of the next run.
    low_sum = ModInt(0)
    l = 0
    
    while l < N:
        # Maximal run [l, r): r keeps extending while A[r] < B[l].
        r = l
        while r < N and (r == l or A[r] < B[l]):
            r += 1
        
        D[l] = low_sum
        low_sum = ModInt(0)
        
        # Decreasing phase: bounds mirrored inside [A[l], B[l]).
        for i in range(l, r):
            C[i] = B[l] - A[i] + A[l]
        
        # Reverse L[A[l]:B[l]] so dfs_dec works on the mirrored value range.
        segment = L[A[l]:B[l]]
        segment.reverse()
        for i, val in enumerate(segment):
            L[A[l] + i] = val
        
        dfs_dec(l, r, A[l])
        
        for i in range(A[l], C[r - 1]):
            low_sum += L[i]
        
        # Reverse back
        segment = L[A[l]:B[l]]
        segment.reverse()
        for i, val in enumerate(segment):
            L[A[l] + i] = val
        
        # Increasing phase with the real upper bounds.
        for i in range(l, r):
            C[i] = B[i]
        
        dfs_inc(l, r, B[l])
        
        # Carry L[B[l]:A[r]] to the next run (empty range after the last run).
        for i in range(B[l], A[r] if r < N else 0):
            low_sum += L[i]
        
        l = r
    
    return L

# ---------- Main Function ----------
def main():
    """Read S from stdin and print the count for the bounds derived from its 'A's.

    For the k-th 'A' in S the upper bound is (number of non-'A' characters
    before it) + 1 and the lower bound is 0; a string without 'A' prints 1.
    """
    line = sys.stdin.readline().strip()
    upper: List[int] = []
    for pos, letter in enumerate(line):
        if letter == 'A':
            # len(upper) 'A's precede this one, so pos - len(upper) chars are not 'A'.
            upper.append(pos + 1 - len(upper))

    count = len(upper)
    lower = [0] * count

    if not count:
        print(1)
        return

    # Force the lower bounds to be non-decreasing (running maximum).
    for k in range(1, count):
        lower[k] = max(lower[k], lower[k - 1])

    table = number_of_increasing_sequence(lower, upper)
    total = ModInt(0)
    for idx in range(lower[-1], upper[-1]):
        total += table[idx]

    print(total)

if __name__ == "__main__":
    # Run the solver only when executed as a script, not when imported.
    main()
0