結果
| 問題 | No.3370 AB → BA |
|---|---|
| コンテスト | |
| ユーザー | |
| 提出日時 | 2026-01-08 22:59:38 |
| 言語 | PyPy3 (7.3.17) |
| 結果 | TLE |
| 実行時間 | - |
| コード長 | 15,982 bytes |
| 記録 | |
| コンパイル時間 | 285 ms |
| コンパイル使用メモリ | 82,296 KB |
| 実行使用メモリ | 276,680 KB |
| 最終ジャッジ日時 | 2026-01-08 22:59:48 |
| 合計ジャッジ時間 | 4,591 ms |
| ジャッジサーバーID (参考情報) | judge4 / judge2 |

(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 1 |
| other | TLE * 3 -- * 17 |
ソースコード
import sys
import math
from typing import List, Tuple, Optional
import sys
# NTT-friendly prime: 998244353 = 119 * 2^23 + 1.
MOD = 998244353
# Primitive root modulo MOD, used to derive the NTT roots of unity.
PRIMITIVE_ROOT = 3
# ---------- ModInt class ----------
class ModInt:
    """Integer modulo MOD with operator overloading.

    Arithmetic with another ModInt or a plain int returns a new ModInt whose
    value ``x`` has been reduced with ``% MOD``. Instances are never mutated by
    the operators (there are no in-place dunders), so they may be shared freely.
    """

    __slots__ = ("x",)

    def __init__(self, x=0):
        if isinstance(x, ModInt):
            self.x = x.x
        else:
            # Python's % with a positive modulus already lands in [0, MOD),
            # so no separate fix-up for negative inputs is needed.
            self.x = x % MOD

    def __int__(self):
        return self.x

    def __str__(self):
        return str(self.x)

    def __repr__(self):
        return f"ModInt({self.x})"

    def __eq__(self, other):
        if isinstance(other, ModInt):
            return self.x == other.x
        try:
            return self.x == (other % MOD)
        except TypeError:
            # Unrelated operand (None, str, ...): let Python fall back instead
            # of raising from inside ==.
            return NotImplemented

    def __ne__(self, other):
        eq = self.__eq__(other)
        return eq if eq is NotImplemented else not eq

    def __add__(self, other):
        if isinstance(other, ModInt):
            return ModInt(self.x + other.x)
        return ModInt(self.x + other)

    def __radd__(self, other):
        return self.__add__(other)

    def __sub__(self, other):
        if isinstance(other, ModInt):
            return ModInt(self.x - other.x)
        return ModInt(self.x - other)

    def __rsub__(self, other):
        if isinstance(other, ModInt):
            return ModInt(other.x - self.x)
        return ModInt(other - self.x)

    def __mul__(self, other):
        if isinstance(other, ModInt):
            return ModInt(self.x * other.x)
        return ModInt(self.x * other)

    def __rmul__(self, other):
        return self.__mul__(other)

    def __truediv__(self, other):
        if isinstance(other, ModInt):
            return self * other.inv()
        return self * ModInt(other).inv()

    def __rtruediv__(self, other):
        if isinstance(other, ModInt):
            return other * self.inv()
        return ModInt(other) * self.inv()

    def __pow__(self, exp):
        """Return self**exp; a negative exp raises inv() to -exp (so 0 stays 0)."""
        if exp < 0:
            return ModInt(pow(pow(self.x, MOD - 2, MOD), -exp, MOD))
        return ModInt(pow(self.x, exp, MOD))

    def inv(self):
        """Multiplicative inverse via Fermat's little theorem (MOD is prime); inv(0) == 0."""
        return ModInt(pow(self.x, MOD - 2, MOD))

    def __neg__(self):
        return ModInt(-self.x)

    def __hash__(self):
        # Consistent with __eq__ against reduced ints: hash(ModInt(k)) == hash(k).
        return hash(self.x)

    def __lt__(self, other):
        if isinstance(other, ModInt):
            return self.x < other.x
        return self.x < other

    def __le__(self, other):
        if isinstance(other, ModInt):
            return self.x <= other.x
        return self.x <= other

    def __gt__(self, other):
        return not self.__le__(other)

    def __ge__(self, other):
        return not self.__lt__(other)
# ---------- Combination class ----------
class Combination:
    """Factorial, inverse-factorial and modular-inverse tables mod MOD, grown on demand.

    After build(n): fact[k] == k!, ifact[k] == 1/k! and inv_list[k] == 1/k for
    1 <= k < n; index 0 holds fact[0] == ifact[0] == 1 and inv_list[0] == 0.
    """

    def __init__(self):
        self.fact = [ModInt(1)]
        self.ifact = [ModInt(1)]
        self.inv_list = [ModInt(0)]
        self.N = 0  # tables are valid for indices < N (index 0 is always valid)

    def build(self, n):
        """Extend the tables so that indices 0 .. n - 1 are valid."""
        if n <= self.N:
            return
        fact, ifact, inv_list = self.fact, self.ifact, self.inv_list
        start = len(fact)  # entries [0, start) are already correct
        for i in range(start, n):
            fact.append(fact[i - 1] * i)
        ifact.extend([ModInt(0)] * (n - start))
        # A single modular inverse, then walk down: 1/(i-1)! = i * (1/i!).
        ifact[n - 1] = fact[n - 1].inv()
        for i in range(n - 1, start, -1):
            ifact[i - 1] = ifact[i] * i
        # 1/i = (i-1)! / i!. Computed for every new index; the previous version
        # skipped i == 1 and the first index of each incremental build,
        # leaving those inverses at 0.
        for i in range(start, n):
            inv_list.append(ifact[i] * fact[i - 1])
        self.N = n

    def fac(self, k):
        """Return k!."""
        if k >= self.N:
            self.build(k + 1)
        return self.fact[k]

    def ifac(self, k):
        """Return 1 / k!."""
        if k >= self.N:
            self.build(k + 1)
        return self.ifact[k]

    def inv(self, k):
        """Return 1 / k (0 for k == 0)."""
        if k >= self.N:
            self.build(k + 1)
        return self.inv_list[k]

    def comb(self, a, b):
        """Binomial coefficient C(a, b); 0 when b < 0 or a < b."""
        if a < b or b < 0:
            return ModInt(0)
        return self.fac(a) * self.ifac(a - b) * self.ifac(b)

    def perm(self, a, b):
        """Number of ordered selections a! / (a - b)!; 0 when b < 0 or a < b."""
        if a < b or b < 0:
            return ModInt(0)
        return self.fac(a) * self.ifac(a - b)
# ---------- FFT Info ----------
class FFTInfo:
    """Root-of-unity tables for the NTT modulo MOD (a port of ACL's fft_info).

    root[i] is a primitive 2^i-th root of unity and iroot[i] its inverse, for
    0 <= i <= rank2, where 2^rank2 is the largest power of two dividing MOD - 1.
    rate2/irate2 and rate3/irate3 are the twiddle increments consumed by the
    radix-2 and radix-4 passes of butterfly / butterfly_inv.
    """

    def __init__(self, g=PRIMITIVE_ROOT):
        m1 = MOD - 1
        # rank2 = countr_zero(MOD - 1). Using (MOD - 1).bit_length() - 1 instead
        # gives 29 for 998244353 = 119 * 2^23 + 1, which makes root[2] an element
        # of order 119 rather than a 4th root of unity; the right value is 23.
        self.rank2 = (m1 & -m1).bit_length() - 1
        self.root = [ModInt(0)] * (self.rank2 + 1)
        self.iroot = [ModInt(0)] * (self.rank2 + 1)
        self.root[self.rank2] = ModInt(g) ** (m1 >> self.rank2)
        self.iroot[self.rank2] = self.root[self.rank2].inv()
        for i in range(self.rank2 - 1, -1, -1):
            self.root[i] = self.root[i + 1] * self.root[i + 1]
            self.iroot[i] = self.iroot[i + 1] * self.iroot[i + 1]
        self.rate2 = [ModInt(0)] * max(0, self.rank2 - 1)
        self.irate2 = [ModInt(0)] * max(0, self.rank2 - 1)
        self.rate3 = [ModInt(0)] * max(0, self.rank2 - 2)
        self.irate3 = [ModInt(0)] * max(0, self.rank2 - 2)
        prod = ModInt(1)
        iprod = ModInt(1)
        for i in range(len(self.rate2)):
            self.rate2[i] = self.root[i + 2] * prod
            self.irate2[i] = self.iroot[i + 2] * iprod
            prod *= self.iroot[i + 2]
            iprod *= self.root[i + 2]
        prod = ModInt(1)
        iprod = ModInt(1)
        for i in range(len(self.rate3)):
            self.rate3[i] = self.root[i + 3] * prod
            self.irate3[i] = self.iroot[i + 3] * iprod
            prod *= self.iroot[i + 3]
            iprod *= self.root[i + 3]


# ---------- FFT Functions ----------
# Plain-int copies of the FFTInfo tables, built once on first use. Previously
# every butterfly call constructed a new FFTInfo (several modular
# exponentiations), which dominated the cost of small transforms.
_FFT_TABLES = {}


def _fft_tables():
    """Return cached int tables (imag, iimag, rate2, irate2, rate3, irate3)."""
    tables = _FFT_TABLES.get(MOD)
    if tables is None:
        info = FFTInfo()
        tables = (
            int(info.root[2]),
            int(info.iroot[2]),
            [int(x) for x in info.rate2],
            [int(x) for x in info.irate2],
            [int(x) for x in info.rate3],
            [int(x) for x in info.irate3],
        )
        _FFT_TABLES[MOD] = tables
    return tables


def _trailing_ones(s):
    """Number of trailing 1 bits of s >= 0 (ACL's countr_zero(~s))."""
    t = s + 1
    return (t & -t).bit_length() - 1


def butterfly(a: List[ModInt]) -> None:
    """In-place forward NTT of a, as in ACL; len(a) must be a power of two.

    The output is in the permuted order that butterfly_inv expects, so the pair
    is meant for convolution via pointwise products, not as a plain DFT.
    Arithmetic runs on plain ints; results are stored back into a as ModInt.
    """
    n = len(a)
    h = (n & -n).bit_length() - 1  # countr_zero(n) == log2(n) for n == 2^h
    if h <= 0:
        return  # length 0 or 1: nothing to transform
    mod = MOD
    imag, _, rate2, _, rate3, _ = _fft_tables()
    v = [int(x) for x in a]
    length = 0
    while length < h:
        if h - length == 1:
            # Radix-2 pass.
            p = 1 << (h - length - 1)
            rot = 1
            for s in range(1 << length):
                offset = s << (h - length)
                for i in range(offset, offset + p):
                    l = v[i]
                    r = v[i + p] * rot % mod
                    v[i] = (l + r) % mod
                    v[i + p] = (l - r) % mod
                if s + 1 != (1 << length):
                    rot = rot * rate2[_trailing_ones(s)] % mod
            length += 1
        else:
            # Radix-4 pass.
            p = 1 << (h - length - 2)
            rot = 1
            for s in range(1 << length):
                rot2 = rot * rot % mod
                rot3 = rot2 * rot % mod
                offset = s << (h - length)
                for i in range(offset, offset + p):
                    a0 = v[i]
                    a1 = v[i + p] * rot % mod
                    a2 = v[i + 2 * p] * rot2 % mod
                    a3 = v[i + 3 * p] * rot3 % mod
                    a1na3imag = (a1 - a3) * imag % mod
                    v[i] = (a0 + a2 + a1 + a3) % mod
                    v[i + p] = (a0 + a2 - a1 - a3) % mod
                    v[i + 2 * p] = (a0 - a2 + a1na3imag) % mod
                    v[i + 3 * p] = (a0 - a2 - a1na3imag) % mod
                if s + 1 != (1 << length):
                    rot = rot * rate3[_trailing_ones(s)] % mod
            length += 2
    for i, x in enumerate(v):
        a[i] = ModInt(x)


def butterfly_inv(a: List[ModInt]) -> None:
    """In-place inverse of butterfly, as in ACL, WITHOUT the final 1/len(a) scaling.

    len(a) must be a power of two. Arithmetic runs on plain ints; results are
    stored back into a as ModInt.
    """
    n = len(a)
    h = (n & -n).bit_length() - 1  # countr_zero(n) == log2(n) for n == 2^h
    if h <= 0:
        return  # length 0 or 1: nothing to transform
    mod = MOD
    _, iimag, _, irate2, _, irate3 = _fft_tables()
    v = [int(x) for x in a]
    length = h
    while length:
        if length == 1:
            # Radix-2 pass.
            p = 1 << (h - length)
            irot = 1
            for s in range(1 << (length - 1)):
                offset = s << (h - length + 1)
                for i in range(offset, offset + p):
                    l = v[i]
                    r = v[i + p]
                    v[i] = (l + r) % mod
                    v[i + p] = (l - r) * irot % mod
                if s + 1 != (1 << (length - 1)):
                    irot = irot * irate2[_trailing_ones(s)] % mod
            length -= 1
        else:
            # Radix-4 pass.
            p = 1 << (h - length)
            irot = 1
            for s in range(1 << (length - 2)):
                irot2 = irot * irot % mod
                irot3 = irot2 * irot % mod
                offset = s << (h - length + 2)
                for i in range(offset, offset + p):
                    a0 = v[i]
                    a1 = v[i + p]
                    a2 = v[i + 2 * p]
                    a3 = v[i + 3 * p]
                    a2na3iimag = (a2 - a3) * iimag % mod
                    v[i] = (a0 + a1 + a2 + a3) % mod
                    v[i + p] = (a0 - a1 + a2na3iimag) * irot % mod
                    v[i + 2 * p] = (a0 + a1 - a2 - a3) * irot2 % mod
                    v[i + 3 * p] = (a0 - a1 - a2na3iimag) * irot3 % mod
                if s + 1 != (1 << (length - 2)):
                    irot = irot * irate3[_trailing_ones(s)] % mod
            length -= 2
    for i, x in enumerate(v):
        a[i] = ModInt(x)
def convolution_fft(a: List[ModInt], b: List[ModInt]) -> List[ModInt]:
    """Linear convolution of a and b through a zero-padded NTT.

    Both inputs are copied into new padded lists (the caller's lists are left
    untouched), sized to the smallest power of two holding len(a) + len(b) - 1
    terms.
    """
    out_len = len(a) + len(b) - 1
    size = 1 << (out_len - 1).bit_length()
    fa = a + [ModInt(0)] * (size - len(a))
    fb = b + [ModInt(0)] * (size - len(b))
    butterfly(fa)
    butterfly(fb)
    for k, y in enumerate(fb):
        fa[k] *= y
    butterfly_inv(fa)
    # butterfly_inv does not divide by the transform length; do it here.
    scale = ModInt(size).inv()
    return [x * scale for x in fa[:out_len]]
def convolution_mod(a: List[ModInt], b: List[ModInt]) -> List[ModInt]:
    """Return the linear convolution of a and b modulo MOD.

    Empty input yields []. When either side has at most 60 terms a schoolbook
    product is used; otherwise the work is delegated to convolution_fft.
    """
    if not a or not b:
        return []
    n, m = len(a), len(b)
    if min(n, m) > 60:
        return convolution_fft(a, b)
    # Short inputs: quadratic product is cheaper than NTT setup.
    res = [ModInt(0)] * (n + m - 1)
    for i, x in enumerate(a):
        for j, y in enumerate(b):
            res[i + j] += x * y
    return res
# ---------- Counting Path on Grid ----------
class CountingPathOnGrid:
    """Weighted lattice-path transfer across an H x W grid block, via NTT.

    solve(H, W, w) takes H + W - 1 boundary weights, w[:H] on the left side and
    w[H:] on the bottom (per the section comments below), and returns H + W - 1
    accumulated values: res[:W] along the top and res[W - 1:] along the right
    side.
    """

    # F[k] = butterfly([0!, 1!, ..., (2^k - 1)!]); filled lazily by make_f.
    F = [[] for _ in range(25)]
    # One factorial table shared by every call. It only grows, so solve() no
    # longer rebuilds factorials (plus a modular inverse) on each invocation.
    _comb = Combination()

    @classmethod
    def make_f(cls, k: int):
        """Cache the forward NTT of the first 2^k factorials in F[k]."""
        size = 1 << k
        cls._comb.build(size)
        cls.F[k] = [cls._comb.fac(i) for i in range(size)]
        butterfly(cls.F[k])

    @classmethod
    def _times_factorials(cls, A: List[ModInt], k: int) -> None:
        """In place: cyclic product of A (length 2^k) with [0!, 1!, ...].

        The result is left multiplied by 2^k (butterfly_inv does not scale);
        callers multiply by the inverse of 2^k themselves.
        """
        if not cls.F[k]:
            cls.make_f(k)
        f = cls.F[k]
        butterfly(A)
        for i, fi in enumerate(f):
            A[i] *= fi
        butterfly_inv(A)

    @classmethod
    def solve(cls, H: int, W: int, w: List[ModInt]) -> List[ModInt]:
        """Propagate boundary weights w through an H x W grid; see the class docstring."""
        assert len(w) == H + W - 1
        if not H or not W:
            return []
        comb = cls._comb
        comb.build(H + W)
        logHW = (H + W - 1).bit_length()
        size = 1 << logHW
        iz = ModInt(size).inv()
        # "sita" (shita) = bottom: skip both bottom passes when the bottom
        # inputs w[H:] are all zero.
        use_sita = any(w[H - 1 + i] != 0 for i in range(1, W))
        res = [ModInt(0)] * (H + W - 1)

        # Left -> Up
        A = [ModInt(0)] * size
        for i in range(H):
            A[i] = w[H - 1 - i] * comb.ifac(H - 1 - i)
        cls._times_factorials(A, logHW)
        for i in range(W):
            res[i] += A[H - 1 + i] * iz * comb.ifac(i)

        # Bottom -> Right
        if use_sita:
            A = [ModInt(0)] * size
            for i in range(1, W):
                A[i] = w[H - 1 + i] * comb.ifac(W - 1 - i)
            cls._times_factorials(A, logHW)
            for i in range(H):
                res[H + W - 2 - i] += A[W - 1 + i] * iz * comb.ifac(i)

        # Left -> Right
        B = w[:H]
        C = [comb.comb(H + W - 2 - i, W - 1) for i in range(H)]
        conv = convolution_mod(B, C)
        for i in range(1, H):
            res[W - 1 + i] += conv[H - 1 + i]

        # Bottom -> Up
        if use_sita:
            B = [ModInt(0)] * W
            for i in range(W - 1):
                B[i] = w[H + W - 2 - i]
            C = [comb.comb(H + W - 2 - i, H - 1) for i in range(W)]
            conv = convolution_mod(B, C)
            for i in range(1, W):
                res[W - 1 - i] += conv[W - 1 + i]
        return res
# ---------- Main Algorithm ----------
def number_of_increasing_sequence(A: List[int], B: List[int]) -> List[ModInt]:
    """Divide-and-conquer DP over per-index value bounds; returns the table L.

    For index i the admissible values appear to be [A[i], B[i]) (half-open,
    judging by the range(...) loops below). L has length B[-1] and starts as the
    indicator of A[0]. D (length N + 1) carries per-index boundary values
    between grid blocks, and each block is resolved with
    CountingPathOnGrid.solve.

    NOTE(review): main() sums L[A[-1]:B[-1]], which suggests L[v] ends up as
    the number of non-decreasing sequences x with A[i] <= x[i] < B[i] and
    x[N - 1] == v. This is inferred from the caller, not established here --
    confirm. The code also appears to assume A and B are non-decreasing with
    A[i] < B[i].

    Returns [] when A is empty.
    """
    N = len(A)
    if N == 0:
        return []
    D = [ModInt(0)] * (N + 1)
    L = [ModInt(0)] * B[-1]
    L[A[0]] = ModInt(1)
    # Per-index cut-offs for the current phase; rewritten before each dfs call.
    C = [0] * N

    def dfs_inc(l: int, r: int, d: int):
        # Recursive pass over indices [l, r) while C holds B; d is the first
        # L index this call works from.
        if d == C[r - 1]:
            return
        if r - l == 1:
            # Single index: fold D[l] into L[d], take prefix sums of L over
            # [d, C[l]), and read the running total back into D[l].
            L[d] += D[l]
            for i in range(d + 1, C[l]):
                L[i] += L[i - 1]
            D[l] = L[C[l] - 1]
            return
        m = (l + r) // 2
        if l < m:
            dfs_inc(l, m, d)
        y = C[m]
        Hi = y - d
        Wi = r - m
        if Hi and Wi:
            # Hi x Wi grid: left boundary is L[y - 1], ..., L[d]; bottom is
            # D[m:r]; they share the corner entry w[Hi - 1].
            w = [ModInt(0)] * (Hi + Wi - 1)
            for i in range(Hi):
                w[i] = L[y - 1 - i]
            for i in range(Wi):
                w[Hi - 1 + i] += D[m + i]
            w = CountingPathOnGrid.solve(Hi, Wi, w)
            # Write the outgoing boundary back into D[m:r] and L[d:y].
            for i in range(Wi):
                D[m + i] = w[i]
            for i in range(Hi):
                L[y - 1 - i] = w[Wi - 1 + i]
        if m < r:
            dfs_inc(m, r, y)

    def dfs_dec(l: int, r: int, d: int):
        # Mirror of dfs_inc, run while L[A[l]:B[l]] is reversed and
        # C[i] = B[l] - A[i] + A[l] (see the main loop below).
        if d == C[l]:
            return
        if r - l == 1:
            # Single index: fold D[l] into L[C[l] - 1], take suffix sums of L
            # down to d, set D[l] = L[d] and add the sum of L[d:C[l]] to D[l + 1].
            L[C[l] - 1] += D[l]
            for i in range(C[l] - 2, d - 1, -1):
                L[i] += L[i + 1]
            D[l] = L[d]
            for i in range(d, C[l]):
                D[l + 1] += L[i]
            return
        m = (l + r) // 2
        y = C[m - 1]
        if l < m:
            dfs_dec(l, m, y)
        Hi = y - d
        Wi = m - l
        if Hi and Wi:
            # Hi x Wi grid: left boundary is L[d], ..., L[y - 1]; bottom is
            # D[l:m]; they share the corner entry w[Hi - 1].
            w = [ModInt(0)] * (Hi + Wi - 1)
            for i in range(Hi):
                w[i] = L[d + i]
            for i in range(Wi):
                w[Hi - 1 + i] += D[l + i]
            w = CountingPathOnGrid.solve(Hi, Wi, w)
            for i in range(Hi):
                L[d + i] = w[Wi - 1 + i]
            for i in range(Wi):
                D[l + i] = w[i]
        for i in range(C[m], y):
            D[m] += L[i]
        if m < r:
            dfs_dec(m, r, d)

    low_sum = ModInt(0)
    l = 0
    while l < N:
        # Group indices [l, r): extend r while A[r] < B[l].
        r = l
        while r < N and (r == l or A[r] < B[l]):
            r += 1
        # Mass carried over from the previous group (see the end of the loop).
        D[l] = low_sum
        low_sum = ModInt(0)
        # Decreasing phase: cut-offs measured on the reversed window.
        for i in range(l, r):
            C[i] = B[l] - A[i] + A[l]
        # Reverse L[A[l]:B[l]]
        segment = L[A[l]:B[l]]
        segment.reverse()
        for i, val in enumerate(segment):
            L[A[l] + i] = val
        dfs_dec(l, r, A[l])
        for i in range(A[l], C[r - 1]):
            low_sum += L[i]
        # Reverse back
        segment = L[A[l]:B[l]]
        segment.reverse()
        for i, val in enumerate(segment):
            L[A[l] + i] = val
        # Increasing phase: cut-offs are the upper bounds themselves.
        for i in range(l, r):
            C[i] = B[i]
        dfs_inc(l, r, B[l])
        # Sum of L[B[l]:A[r]] becomes D[r] on the next iteration.
        for i in range(B[l], A[r] if r < N else 0):
            low_sum += L[i]
        l = r
    return L
# ---------- Main Function ----------
def main():
    """Read S from stdin, build per-'A' bounds and print the resulting count mod MOD."""
    S = sys.stdin.readline().strip()
    # The k-th 'A' (0-based) at position pos has (pos - k) non-'A' characters
    # before it; that count plus one is used as its (exclusive) upper bound.
    B = []
    for pos, ch in enumerate(S):
        if ch == 'A':
            B.append(pos - len(B) + 1)
    N = len(B)
    if not N:
        print(1)
        return
    # Lower bounds: all zero, kept non-decreasing by a running max.
    A = [0] * N
    for k in range(1, N):
        if A[k] < A[k - 1]:
            A[k] = A[k - 1]
    ans = number_of_increasing_sequence(A, B)
    print(sum(ans[A[-1]:B[-1]], ModInt(0)))
# Run the solver only when executed as a script, not when imported.
if __name__ == "__main__":
    main()