結果
| 問題 | No.3370 AB → BA |
| コンテスト | |
| ユーザー |
|
| 提出日時 | 2026-01-08 23:33:58 |
| 言語 | PyPy3 (7.3.17) |
| 結果 |
MLE
|
| 実行時間 | - |
| コード長 | 4,042 bytes |
| 記録 | |
| コンパイル時間 | 201 ms |
| コンパイル使用メモリ | 82,948 KB |
| 実行使用メモリ | 857,920 KB |
| 最終ジャッジ日時 | 2026-01-08 23:34:09 |
| 合計ジャッジ時間 | 4,645 ms |
|
ジャッジサーバーID (参考情報) |
judge2 / judge5 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 1 |
| other | MLE * 1 -- * 19 |
ソースコード
import sys
input = sys.stdin.readline
MOD = 998244353
PRIMITIVE_ROOT = 3
# ----------------- ModInt -----------------
class ModInt:
__slots__ = ['x']
def __init__(self,x):
self.x = x % MOD
def __add__(self,other):
return ModInt(self.x + (other.x if isinstance(other,ModInt) else other))
def __sub__(self,other):
return ModInt(self.x - (other.x if isinstance(other,ModInt) else other))
def __mul__(self,other):
return ModInt(self.x * (other.x if isinstance(other,ModInt) else other))
def __pow__(self,power):
return ModInt(pow(self.x,power,MOD))
def inv(self):
return ModInt(pow(self.x,MOD-2,MOD))
def __truediv__(self,other):
return self*self.inv() if isinstance(other,ModInt) else self*ModInt(other).inv()
def __iadd__(self,other):
self.x = (self.x + (other.x if isinstance(other,ModInt) else other)) % MOD
return self
def __repr__(self):
return str(self.x)
def val(self):
return self.x
# ----------------- NTT / Convolution -----------------
def modinv(x): return pow(x,MOD-2,MOD)
def ntt(a,invert):
n = len(a)
j = 0
for i in range(1,n):
bit = n >> 1
while j & bit:
j ^= bit
bit >>= 1
j ^= bit
if i < j:
a[i],a[j] = a[j],a[i]
length = 2
while length <= n:
wlen = pow(PRIMITIVE_ROOT,(MOD-1)//length,MOD)
if invert:
wlen = modinv(wlen)
for i in range(0,n,length):
w = 1
for j in range(i,i+length//2):
u = a[j]
v = a[j+length//2]*w % MOD
a[j] = (u+v)%MOD
a[j+length//2] = (u-v+MOD)%MOD
w = w*wlen%MOD
length <<= 1
if invert:
ninv = modinv(n)
for i in range(n):
a[i] = a[i]*ninv%MOD
def convolution(a,b):
n = 1
while n < len(a)+len(b)-1: n <<= 1
fa = [x.x if isinstance(x,ModInt) else x for x in a]+[0]*(n-len(a))
fb = [x.x if isinstance(x,ModInt) else x for x in b]+[0]*(n-len(b))
ntt(fa,False)
ntt(fb,False)
for i in range(n): fa[i] = fa[i]*fb[i]%MOD
ntt(fa,True)
return [ModInt(x) for x in fa[:len(a)+len(b)-1]]
# ----------------- Binomial -----------------
class Binomial:
def __init__(self, MAX):
self.fact = [ModInt(1)]*(MAX+1)
self.invfact = [ModInt(1)]*(MAX+1)
for i in range(1,MAX+1):
self.fact[i] = self.fact[i-1]*i
self.invfact[MAX] = self.fact[MAX].inv()
for i in range(MAX,0,-1):
self.invfact[i-1] = self.invfact[i]*i
def C(self,n,k):
if n<0 or k<0 or k>n: return ModInt(0)
return self.fact[n]*self.invfact[k]*self.invfact[n-k]
# ----------------- Count increasing sequences (D&C + convolution) -----------------
def count_increase_sequences_with_upper_bounds(A,C):
n = len(A)
if n==0: return []
if n==1:
return [C[0]]*A[0]
m = n//2
LA,LC = A[:m],C[:m]
RA,RC = [A[i]-A[m-1] for i in range(m,n)],C[m:]
Lseq = count_increase_sequences_with_upper_bounds(LA,LC)
# Right convolution
binom = Binomial(sum(RA)+len(RA)+10)
tmp = [binom.C(RA[0]-1+i,i) for i in range(len(Lseq))]
conv_res = convolution(tmp,Lseq)
# Naive add (sum if needed)
res = [ModInt(0)]*(max(len(conv_res),RA[-1]))
for i in range(len(conv_res)):
res[i] = conv_res[i]
return res
# ----------------- Solve -----------------
def solve():
s = input().strip()
n = len(s)
a = s.count('A')
b = s.count('B')
if a==0 or b==0:
print(1)
return
lb,ub = [],[]
h = 0
for c in s:
if c=='A':
lb.append(0)
ub.append(h+1)
else:
h += 1
# C vector = all ones
C = [ModInt(1) for _ in range(len(lb))]
tmp = count_increase_sequences_with_upper_bounds(ub,C)
ans = ModInt(0)
for x in tmp:
ans += x
print(ans.val())
if __name__=="__main__":
solve()