結果

問題 No.3370 AB → BA
コンテスト
ユーザー Aralov Otabek
提出日時 2026-01-08 23:33:58
言語 PyPy3
(7.3.17)
結果
MLE  
実行時間 -
コード長 4,042 bytes
記録
記録タグの例:
初AC ショートコード 純ショートコード 純主流ショートコード 最速実行時間
コンパイル時間 201 ms
コンパイル使用メモリ 82,948 KB
実行使用メモリ 857,920 KB
最終ジャッジ日時 2026-01-08 23:34:09
合計ジャッジ時間 4,645 ms
ジャッジサーバーID
(参考情報)
judge2 / judge5
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 1
other MLE * 1 -- * 19
権限があれば一括ダウンロードができます

ソースコード

diff #
raw source code

import sys
input = sys.stdin.readline

MOD = 998244353
PRIMITIVE_ROOT = 3

# ----------------- ModInt -----------------
class ModInt:
    __slots__ = ['x']
    def __init__(self,x):
        self.x = x % MOD
    def __add__(self,other):
        return ModInt(self.x + (other.x if isinstance(other,ModInt) else other))
    def __sub__(self,other):
        return ModInt(self.x - (other.x if isinstance(other,ModInt) else other))
    def __mul__(self,other):
        return ModInt(self.x * (other.x if isinstance(other,ModInt) else other))
    def __pow__(self,power):
        return ModInt(pow(self.x,power,MOD))
    def inv(self):
        return ModInt(pow(self.x,MOD-2,MOD))
    def __truediv__(self,other):
        return self*self.inv() if isinstance(other,ModInt) else self*ModInt(other).inv()
    def __iadd__(self,other):
        self.x = (self.x + (other.x if isinstance(other,ModInt) else other)) % MOD
        return self
    def __repr__(self):
        return str(self.x)
    def val(self):
        return self.x

# ----------------- NTT / Convolution -----------------
def modinv(x): return pow(x,MOD-2,MOD)

def ntt(a,invert):
    n = len(a)
    j = 0
    for i in range(1,n):
        bit = n >> 1
        while j & bit:
            j ^= bit
            bit >>= 1
        j ^= bit
        if i < j:
            a[i],a[j] = a[j],a[i]
    length = 2
    while length <= n:
        wlen = pow(PRIMITIVE_ROOT,(MOD-1)//length,MOD)
        if invert:
            wlen = modinv(wlen)
        for i in range(0,n,length):
            w = 1
            for j in range(i,i+length//2):
                u = a[j]
                v = a[j+length//2]*w % MOD
                a[j] = (u+v)%MOD
                a[j+length//2] = (u-v+MOD)%MOD
                w = w*wlen%MOD
        length <<= 1
    if invert:
        ninv = modinv(n)
        for i in range(n):
            a[i] = a[i]*ninv%MOD

def convolution(a,b):
    n = 1
    while n < len(a)+len(b)-1: n <<= 1
    fa = [x.x if isinstance(x,ModInt) else x for x in a]+[0]*(n-len(a))
    fb = [x.x if isinstance(x,ModInt) else x for x in b]+[0]*(n-len(b))
    ntt(fa,False)
    ntt(fb,False)
    for i in range(n): fa[i] = fa[i]*fb[i]%MOD
    ntt(fa,True)
    return [ModInt(x) for x in fa[:len(a)+len(b)-1]]

# ----------------- Binomial -----------------
class Binomial:
    def __init__(self, MAX):
        self.fact = [ModInt(1)]*(MAX+1)
        self.invfact = [ModInt(1)]*(MAX+1)
        for i in range(1,MAX+1):
            self.fact[i] = self.fact[i-1]*i
        self.invfact[MAX] = self.fact[MAX].inv()
        for i in range(MAX,0,-1):
            self.invfact[i-1] = self.invfact[i]*i
    def C(self,n,k):
        if n<0 or k<0 or k>n: return ModInt(0)
        return self.fact[n]*self.invfact[k]*self.invfact[n-k]

# ----------------- Count increasing sequences (D&C + convolution) -----------------
def count_increase_sequences_with_upper_bounds(A,C):
    n = len(A)
    if n==0: return []
    if n==1:
        return [C[0]]*A[0]
    m = n//2
    LA,LC = A[:m],C[:m]
    RA,RC = [A[i]-A[m-1] for i in range(m,n)],C[m:]
    Lseq = count_increase_sequences_with_upper_bounds(LA,LC)
    # Right convolution
    binom = Binomial(sum(RA)+len(RA)+10)
    tmp = [binom.C(RA[0]-1+i,i) for i in range(len(Lseq))]
    conv_res = convolution(tmp,Lseq)
    # Naive add (sum if needed)
    res = [ModInt(0)]*(max(len(conv_res),RA[-1]))
    for i in range(len(conv_res)):
        res[i] = conv_res[i]
    return res

# ----------------- Solve -----------------
def solve():
    s = input().strip()
    n = len(s)
    a = s.count('A')
    b = s.count('B')
    if a==0 or b==0:
        print(1)
        return
    lb,ub = [],[]
    h = 0
    for c in s:
        if c=='A':
            lb.append(0)
            ub.append(h+1)
        else:
            h += 1
    # C vector = all ones
    C = [ModInt(1) for _ in range(len(lb))]
    tmp = count_increase_sequences_with_upper_bounds(ub,C)
    ans = ModInt(0)
    for x in tmp:
        ans += x
    print(ans.val())

if __name__=="__main__":
    solve()
0