結果

問題 No.3370 AB → BA
コンテスト
ユーザー Aralov Otabek
提出日時 2026-01-08 23:38:08
言語 PyPy3
(7.3.17)
結果
MLE  
実行時間 -
コード長 3,481 bytes
記録
記録タグの例:
初AC ショートコード 純ショートコード 純主流ショートコード 最速実行時間
コンパイル時間 333 ms
コンパイル使用メモリ 82,792 KB
実行使用メモリ 857,328 KB
最終ジャッジ日時 2026-01-08 23:38:12
合計ジャッジ時間 4,267 ms
ジャッジサーバーID
(参考情報)
judge2 / judge4
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 1
other MLE * 1 -- * 19
権限があれば一括ダウンロードができます

ソースコード

diff #
raw source code

import sys
input = sys.stdin.readline

MOD = 998244353
PRIMITIVE_ROOT = 3

# ----------------- ModInt -----------------
class ModInt:
    __slots__ = ['x']
    def __init__(self,x):
        self.x = x % MOD
    def __add__(self,other): return ModInt(self.x + (other.x if isinstance(other,ModInt) else other))
    def __sub__(self,other): return ModInt(self.x - (other.x if isinstance(other,ModInt) else other))
    def __mul__(self,other): return ModInt(self.x * (other.x if isinstance(other,ModInt) else other))
    def __pow__(self,p): return ModInt(pow(self.x,p,MOD))
    def inv(self): return ModInt(pow(self.x,MOD-2,MOD))
    def __truediv__(self,other): return self*self.inv() if isinstance(other,ModInt) else self*ModInt(other).inv()
    def __iadd__(self,other): self.x=(self.x+(other.x if isinstance(other,ModInt) else other))%MOD; return self
    def val(self): return self.x

# ----------------- NTT -----------------
def modinv(x): return pow(x,MOD-2,MOD)

def ntt(a,invert):
    n=len(a); j=0
    for i in range(1,n):
        bit=n>>1
        while j&bit:
            j^=bit
            bit>>=1
        j^=bit
        if i<j: a[i],a[j]=a[j],a[i]
    length=2
    while length<=n:
        wlen=pow(PRIMITIVE_ROOT,(MOD-1)//length,MOD)
        if invert: wlen=modinv(wlen)
        for i in range(0,n,length):
            w=1
            for j in range(i,i+length//2):
                u=a[j]; v=a[j+length//2]*w %MOD
                a[j]=(u+v)%MOD
                a[j+length//2]=(u-v+MOD)%MOD
                w=w*wlen%MOD
        length<<=1
    if invert:
        ninv=modinv(n)
        for i in range(n): a[i]=a[i]*ninv%MOD

def convolution(a,b):
    n=1
    while n<len(a)+len(b)-1: n<<=1
    fa=[x.x if isinstance(x,ModInt) else x for x in a]+[0]*(n-len(a))
    fb=[x.x if isinstance(x,ModInt) else x for x in b]+[0]*(n-len(b))
    ntt(fa,False)
    ntt(fb,False)
    for i in range(n): fa[i]=fa[i]*fb[i]%MOD
    ntt(fa,True)
    return [ModInt(x) for x in fa[:len(a)+len(b)-1]]

# ----------------- Binomial -----------------
class Binomial:
    def __init__(self,MAX):
        self.fact=[ModInt(1)]*(MAX+1)
        self.invfact=[ModInt(1)]*(MAX+1)
        for i in range(1,MAX+1): self.fact[i]=self.fact[i-1]*i
        self.invfact[MAX]=self.fact[MAX].inv()
        for i in range(MAX,0,-1): self.invfact[i-1]=self.invfact[i]*i
    def C(self,n,k):
        if n<0 or k<0 or k>n: return ModInt(0)
        return self.fact[n]*self.invfact[k]*self.invfact[n-k]

# ----------------- Count increasing sequences -----------------
def count_inc_seq(A,C):
    n=len(A)
    if n==0: return []
    if n==1: return [C[0]]*A[0]
    m=n//2
    LA,LC=A[:m],C[:m]
    RA,RC=[A[i]-A[m-1] for i in range(m,n)],C[m:]
    Lseq=count_inc_seq(LA,LC)
    if not RA: return Lseq
    max_len=sum(RA)+len(RA)+10
    binom=Binomial(max_len)
    tmp=[binom.C(RA[0]-1+i,i) for i in range(len(Lseq))]
    conv_res=convolution(tmp,Lseq)
    res=[ModInt(0)]*max(len(conv_res),RA[-1])
    for i in range(len(conv_res)): res[i]=conv_res[i]
    return res

# ----------------- Solve -----------------
def solve():
    s=input().strip()
    n=len(s)
    a=s.count('A'); b=s.count('B')
    if a==0 or b==0: print(1); return
    lb,ub=[],[]
    h=0
    for c in s:
        if c=='A': lb.append(0); ub.append(h+1)
        else: h+=1
    C=[ModInt(1) for _ in range(len(lb))]
    tmp=count_inc_seq(ub,C)
    ans=ModInt(0)
    for x in tmp: ans+=x
    print(ans.val())

if __name__=="__main__":
    solve()
0