結果
| 問題 | No.3370 AB → BA |
| コンテスト | |
| ユーザー |
|
| 提出日時 | 2026-01-08 23:38:08 |
| 言語 | PyPy3 (7.3.17) |
| 結果 | MLE |
| 実行時間 | - |
| コード長 | 3,481 bytes |
| 記録 | |
| コンパイル時間 | 333 ms |
| コンパイル使用メモリ | 82,792 KB |
| 実行使用メモリ | 857,328 KB |
| 最終ジャッジ日時 | 2026-01-08 23:38:12 |
| 合計ジャッジ時間 | 4,267 ms |
| ジャッジサーバーID (参考情報) | judge2 / judge4 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 1 |
| other | MLE * 1 -- * 19 |
ソースコード
import sys
# Rebind ``input`` to sys.stdin.readline (this shadows the built-in).
# Unlike the built-in, the returned line keeps its trailing newline,
# so callers are expected to strip it themselves.
input = sys.stdin.readline
# Modulus for every computation in this file: 998244353 = 119 * 2**23 + 1.
MOD = 998244353
# Base from which ntt() derives its roots of unity modulo MOD.
PRIMITIVE_ROOT = 3
# ----------------- ModInt -----------------
class ModInt:
    """Integer residue modulo ``MOD`` with arithmetic operator support.

    The right-hand operand of every binary operator may be either another
    ``ModInt`` or a plain ``int``.  Only left-hand operators are defined,
    so ``ModInt(2) + 3`` works while ``3 + ModInt(2)`` does not.

    Attributes are restricted to ``x`` (the residue, always in
    ``[0, MOD)``) through ``__slots__``.
    """

    __slots__ = ['x']

    def __init__(self, x):
        # Python's % always yields a non-negative result for a positive
        # modulus, so negative inputs are normalised as well.
        self.x = x % MOD

    @staticmethod
    def _raw(other):
        """Return the plain integer carried by ``other`` (ModInt or int)."""
        return other.x if isinstance(other, ModInt) else other

    def __add__(self, other):
        """Return ``self + other`` as a new ModInt."""
        return ModInt(self.x + ModInt._raw(other))

    def __sub__(self, other):
        """Return ``self - other`` as a new ModInt."""
        return ModInt(self.x - ModInt._raw(other))

    def __mul__(self, other):
        """Return ``self * other`` as a new ModInt."""
        return ModInt(self.x * ModInt._raw(other))

    def __pow__(self, p):
        """Return ``self ** p`` as a new ModInt (``p`` is a plain int)."""
        return ModInt(pow(self.x, p, MOD))

    def inv(self):
        """Return the multiplicative inverse, computed as ``x ** (MOD - 2)``.

        The inverse of zero does not exist; this returns ``ModInt(0)`` in
        that case rather than raising.
        """
        return ModInt(pow(self.x, MOD - 2, MOD))

    def __truediv__(self, other):
        """Return ``self / other``, i.e. ``self * other**-1`` modulo MOD.

        The divisor's inverse is used for both ModInt and int operands.
        (Previously a ModInt divisor was ignored and ``self`` was inverted
        instead, which made ``a / b`` evaluate to ``a / a``.)
        """
        divisor = other if isinstance(other, ModInt) else ModInt(other)
        return self * divisor.inv()

    def __iadd__(self, other):
        """Add ``other`` to this object in place and return ``self``.

        Because the object is mutated, every reference to it observes the
        change; take care with lists built as ``[obj] * k``.
        """
        self.x = (self.x + ModInt._raw(other)) % MOD
        return self

    def val(self):
        """Return the residue as a plain int."""
        return self.x
# ----------------- NTT -----------------
def modinv(x):
    """Return ``x ** (MOD - 2) % MOD``, the Fermat inverse of ``x`` modulo MOD.

    For ``x`` divisible by MOD the result is 0, not an error.
    """
    fermat_exponent = MOD - 2
    return pow(x, fermat_exponent, MOD)
def ntt(a, invert):
    """Transform the list ``a`` in place with an iterative NTT modulo MOD.

    Args:
        a: list of integers; overwritten with the transform.  The length
            is not validated -- the butterfly indexing assumes it is a
            power of two.
        invert: when true, the inverse root of unity is used and every
            entry is finally multiplied by the inverse of ``len(a)``.

    Returns:
        None; the result is left in ``a``.
    """
    size = len(a)

    # Bit-reversal permutation: ``rev`` tracks the bit-reversed counter
    # that corresponds to ``idx``.
    rev = 0
    for idx in range(1, size):
        mask = size >> 1
        while rev & mask:
            rev ^= mask
            mask >>= 1
        rev ^= mask
        if idx < rev:
            a[idx], a[rev] = a[rev], a[idx]

    # Butterfly passes over blocks of doubling width.
    span = 2
    while span <= size:
        half = span // 2
        root = pow(PRIMITIVE_ROOT, (MOD - 1) // span, MOD)
        if invert:
            root = modinv(root)
        for start in range(0, size, span):
            twiddle = 1
            for lo in range(start, start + half):
                hi = lo + half
                even = a[lo]
                odd = a[hi] * twiddle % MOD
                a[lo] = (even + odd) % MOD
                a[hi] = (even - odd + MOD) % MOD
                twiddle = twiddle * root % MOD
        span <<= 1

    if invert:
        # Undo the factor of ``size`` introduced by the forward transform.
        scale = modinv(size)
        for idx in range(size):
            a[idx] = a[idx] * scale % MOD
def convolution(a, b):
    """Return the convolution of ``a`` and ``b`` modulo MOD.

    Args:
        a, b: sequences whose items are ``ModInt`` instances or plain
            ints (the two kinds may be mixed).

    Returns:
        A list of ``ModInt`` of length ``len(a) + len(b) - 1``, or ``[]``
        when either operand is empty.
    """
    # With an empty operand the target length is <= len(b) - 1, so the
    # padding count below would go negative and leave one buffer longer
    # than ``n``; ntt() would then run on mismatched lengths.  The
    # convolution with an empty sequence is empty, so return early.
    if not a or not b:
        return []

    out_len = len(a) + len(b) - 1

    # Smallest power of two that can hold the full product.
    n = 1
    while n < out_len:
        n <<= 1

    fa = [x.x if isinstance(x, ModInt) else x for x in a]
    fa += [0] * (n - len(fa))
    fb = [x.x if isinstance(x, ModInt) else x for x in b]
    fb += [0] * (n - len(fb))

    ntt(fa, False)
    ntt(fb, False)
    # Pointwise product in the transformed domain.
    fa = [x * y % MOD for x, y in zip(fa, fb)]
    ntt(fa, True)

    return [ModInt(x) for x in fa[:out_len]]
# ----------------- Binomial -----------------
class Binomial:
    """Factorial and inverse-factorial tables for binomials modulo MOD."""

    def __init__(self, MAX):
        """Precompute ``i!`` and ``(i!)**-1`` as ModInt for ``0 <= i <= MAX``."""
        size = MAX + 1
        fact = [ModInt(1)] * size
        invfact = [ModInt(1)] * size

        for i in range(1, size):
            fact[i] = fact[i - 1] * i

        # Invert only the largest factorial, then walk downwards using
        # (i-1)!^-1 = i!^-1 * i.
        invfact[MAX] = fact[MAX].inv()
        for i in reversed(range(1, size)):
            invfact[i - 1] = invfact[i] * i

        self.fact = fact
        self.invfact = invfact

    def C(self, n, k):
        """Return ``n choose k`` as a ModInt; zero unless ``0 <= k <= n``.

        ``n`` must not exceed the ``MAX`` given at construction, otherwise
        the table lookup raises ``IndexError``.
        """
        if not 0 <= k <= n:
            return ModInt(0)
        return self.fact[n] * self.invfact[k] * self.invfact[n - k]
# ----------------- Count increasing sequences -----------------
def count_inc_seq(A, C):
    """Combine per-position bounds ``A`` with weights ``C`` by halving.

    Behaviour as implemented:
      * empty ``A``          -> ``[]``
      * single element       -> ``C[0]`` repeated ``A[0]`` times
      * otherwise            -> the result for the left half is convolved
        with the kernel ``binom(g - 1 + i, i)`` where
        ``g = A[mid] - A[mid - 1]``, and the outcome is zero-padded up to
        ``A[-1] - A[mid - 1]`` entries when that is longer.

    NOTE(review): only the first and last right-half offsets affect the
    result and the right half of ``C`` is never read.  Confirm that this
    is the intended recurrence; this rewrite deliberately leaves the
    recurrence itself untouched.

    Args:
        A: list of integer bounds.
        C: list of weights, parallel to ``A``.

    Returns:
        A list of weights (``ModInt`` for every multi-element input).
    """
    n = len(A)
    if n == 0:
        return []
    if n == 1:
        return [C[0]] * A[0]

    mid = n // 2
    left = count_inc_seq(A[:mid], C[:mid])

    # Offsets of the right half relative to the last left-half bound;
    # only the first and last one are used.
    first_gap = A[mid] - A[mid - 1]
    last_gap = A[n - 1] - A[mid - 1]

    # Size the factorial table by the largest index actually looked up:
    # C(first_gap - 1 + i, i) for i < len(left) touches fact[] only when
    # first_gap >= 1, and then at most at first_gap + len(left) - 2.
    # The table used to be sized by the sum of all right-half offsets,
    # which could be both too small for these lookups (IndexError) and
    # needlessly huge in memory.
    if first_gap >= 1 and left:
        table_max = first_gap + len(left) - 2
    else:
        table_max = 0
    binom = Binomial(table_max)

    kernel = [binom.C(first_gap - 1 + i, i) for i in range(len(left))]
    conv_res = convolution(kernel, left)

    # Zero-pad up to ``last_gap`` entries; a non-positive pad count
    # yields no padding.  The padding entries share one ModInt(0) object.
    return conv_res + [ModInt(0)] * (last_gap - len(conv_res))
# ----------------- Solve -----------------
def solve():
    """Read one line from ``input`` and print the answer for it.

    Prints 1 when the line lacks either an 'A' or a 'B'.  Otherwise each
    'A' contributes the bound ``1 + (number of non-'A' characters before
    it)``; the bounds go to ``count_inc_seq`` with unit weights and the
    sum of the returned entries is printed.
    """
    text = input().strip()

    if 'A' not in text or 'B' not in text:
        print(1)
        return

    upper = []
    others_seen = 0
    for ch in text:
        if ch == 'A':
            upper.append(others_seen + 1)
        else:
            others_seen += 1

    weights = [ModInt(1) for _ in upper]

    total = ModInt(0)
    for term in count_inc_seq(upper, weights):
        total += term
    print(total.val())
# Run the solver only when this file is executed as a script, not on import.
if __name__=="__main__":
    solve()