結果
| 問題 | No.3370 AB → BA |
| コンテスト | |
| ユーザー | |
| 提出日時 | 2026-01-08 23:30:02 |
| 言語 | PyPy3 (7.3.17) |
| 結果 | TLE |
| 実行時間 | - |
| コード長 | 2,787 bytes |
| 記録 | |
| コンパイル時間 | 301 ms |
| コンパイル使用メモリ | 82,272 KB |
| 実行使用メモリ | 94,132 KB |
| 最終ジャッジ日時 | 2026-01-08 23:30:07 |
| 合計ジャッジ時間 | 4,367 ms |
| ジャッジサーバーID (参考情報) | judge5 / judge4 |

(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 1 |
| other | TLE * 1 -- * 19 |
ソースコード
import sys
# Rebind input() to sys.stdin.readline for faster reads. Unlike the builtin,
# the returned line keeps its trailing newline, so callers must strip it.
input = sys.stdin.readline
MOD = 998244353  # prime modulus shared by every ModInt operation


class ModInt:
    """Residue of an integer modulo ``MOD`` with arithmetic operators.

    The residue is stored in ``x`` and is always reduced into ``[0, MOD)``.
    Every operator accepts another ``ModInt`` or a plain ``int`` on either
    side and returns a *new* ``ModInt``. Arithmetic never modifies an
    existing instance, so one instance may safely be referenced from several
    places (for example ``[ModInt(1)] * n``).
    """

    __slots__ = ("x",)

    def __init__(self, x):
        # Accept an existing ModInt as well as a plain integer.
        self.x = ModInt._value_of(x) % MOD

    @staticmethod
    def _value_of(operand):
        """Return the raw integer behind a ``ModInt`` or the operand itself."""
        return operand.x if isinstance(operand, ModInt) else operand

    def __add__(self, other):
        return ModInt(self.x + ModInt._value_of(other))

    # Addition is commutative, so the reflected form is the same operation.
    __radd__ = __add__

    def __sub__(self, other):
        return ModInt(self.x - ModInt._value_of(other))

    def __rsub__(self, other):
        return ModInt(ModInt._value_of(other) - self.x)

    def __mul__(self, other):
        return ModInt(self.x * ModInt._value_of(other))

    # Multiplication is commutative as well.
    __rmul__ = __mul__

    def __pow__(self, power):
        """Return ``self`` raised to the integer ``power`` modulo ``MOD``."""
        return ModInt(pow(self.x, power, MOD))

    def inv(self):
        """Return the multiplicative inverse, computed as ``x ** (MOD - 2)``.

        This relies on ``MOD`` being prime. ``ModInt(0).inv()`` yields
        ``ModInt(0)`` rather than raising.
        """
        return ModInt(pow(self.x, MOD - 2, MOD))

    def __truediv__(self, other):
        return self * ModInt(other).inv()

    def __rtruediv__(self, other):
        return ModInt(other) * self.inv()

    def __iadd__(self, other):
        # Return a fresh object instead of mutating ``self``: an in-place
        # update would silently change every alias of this instance.
        return self + other

    def __eq__(self, other):
        if isinstance(other, ModInt):
            return self.x == other.x
        if isinstance(other, int):
            # Compare against the stored residue so that equality stays
            # consistent with __hash__.
            return self.x == other
        return NotImplemented

    def __hash__(self):
        return hash(self.x)

    def __repr__(self):
        return str(self.x)

    def val(self):
        """Return the residue as a plain ``int`` in ``[0, MOD)``."""
        return self.x
# Naive convolution (intended for small inputs, n <= 200)
def conv_naive(a, b):
    """Return the convolution of sequences ``a`` and ``b`` as ``ModInt`` values.

    Entry ``t`` of the result is the sum of ``a[i] * b[j]`` over all
    ``i + j == t``. The result has ``len(a) + len(b) - 1`` entries and is
    computed with a plain double loop (``len(a) * len(b)`` products).
    If either operand is empty the result is an empty list.
    """
    # With an empty operand there are no products at all; without this guard
    # the size formula would produce len - 1 spurious zero entries.
    if not a or not b:
        return []
    res = [ModInt(0) for _ in range(len(a) + len(b) - 1)]
    for i, x in enumerate(a):
        for j, y in enumerate(b):
            res[i + j] = res[i + j] + x * y
    return res
# Factorial table
class Binomial:
    """Factorial and inverse-factorial tables for binomial coefficients.

    ``fact[i]`` holds ``i!`` and ``invfact[i]`` its modular inverse, for
    ``0 <= i <= MAX``, both stored as ``ModInt`` values.
    """

    def __init__(self, MAX):
        """Build both tables for arguments ``0..MAX`` (``MAX`` must be >= 0)."""
        fact = [ModInt(1)]
        for i in range(1, MAX + 1):
            fact.append(fact[-1] * i)
        # Invert only the largest factorial, then walk downwards using
        # 1/(i-1)! = (1/i!) * i, so a single modular inverse is needed.
        # Placeholders are None rather than one shared ModInt instance.
        invfact = [None] * (MAX + 1)
        invfact[MAX] = fact[MAX].inv()
        for i in range(MAX, 0, -1):
            invfact[i - 1] = invfact[i] * i
        self.fact = fact
        self.invfact = invfact

    def C(self, n, k):
        """Return ``n`` choose ``k``; zero when ``n < 0``, ``k < 0`` or ``k > n``.

        ``n`` must not exceed the ``MAX`` the tables were built with,
        otherwise the table lookup raises ``IndexError``.
        """
        if n < 0 or k < 0 or k > n:
            return ModInt(0)
        return self.fact[n] * self.invfact[k] * self.invfact[n - k]
# Row-by-row DP for count_increase_sequences_with_upper_bounds
def count_increase_sequences_with_upper_bounds(A, C):
    """Run the bounded-sequence DP and return its final row.

    Row 0 has ``A[0]`` cells, each equal to ``C[0]``. For every later
    index ``i`` the row has ``A[i]`` cells, where cell ``j`` is the sum of
    the previous row's cells ``0..j`` (clipped to that row's length), plus
    ``C[i]`` when ``j == 0``.

    Args:
        A: per-position row lengths (the upper bounds).
        C: per-position weights, ``ModInt`` or ``int``; same length as ``A``.

    Returns:
        The last row as a list. For ``len(A) == 1`` the cells are references
        to ``C[0]`` itself; otherwise they are fresh ``ModInt`` objects.
        An empty ``A`` gives an empty list.

    Each row is filled with a running prefix sum, so the total work is
    ``sum(A)`` additions rather than one inner loop per cell. That is still
    quadratic when the bounds grow linearly with ``len(A)``.
    """
    if len(A) == 0:
        return []
    dp = [C[0]] * A[0]
    for i in range(1, len(A)):
        prev_len = len(dp)
        running = ModInt(0)  # sum of dp[0..j], clipped to the previous row
        new_dp = []
        for j in range(A[i]):
            if j < prev_len:
                # Plain '+' always builds a new ModInt, so stored cells never
                # alias the accumulator.
                running = running + dp[j]
            if j == 0:
                cell = running + C[i]
            else:
                cell = running + 0  # independent copy of the running sum
            new_dp.append(cell)
        dp = new_dp
    return dp
def solve():
    """Read one string from standard input and print the computed count.

    If the string contains no 'A' or no 'B', the answer printed is 1.
    Otherwise each 'A' gets an upper bound of one plus the number of
    non-'A' characters before it. These bounds, with all weights set to 1,
    are passed to ``count_increase_sequences_with_upper_bounds`` and the sum
    of the returned row is printed as a plain integer.
    """
    s = input().strip()
    if s.count('A') == 0 or s.count('B') == 0:
        print(1)
        return
    ub = []
    h = 0  # number of non-'A' characters seen so far
    for c in s:
        if c == 'A':
            ub.append(h + 1)
        else:
            h += 1
    # C vector = all ones
    C = [ModInt(1) for _ in ub]
    tmp = count_increase_sequences_with_upper_bounds(ub, C)
    ans = ModInt(0)
    for x in tmp:
        ans = ans + x
    print(ans.val())
# Run the solver only when this file is executed as a script.
if __name__ == "__main__":
    solve()