結果
| 問題 | No.3370 AB → BA |
| コンテスト | |
| ユーザー | |
| 提出日時 | 2026-01-08 23:21:49 |
| 言語 | PyPy3 (7.3.17) |
| 結果 | TLE |
| 実行時間 | - |
| コード長 | 2,378 bytes |
| 記録 | |
| コンパイル時間 | 334 ms |
| コンパイル使用メモリ | 82,400 KB |
| 実行使用メモリ | 271,892 KB |
| 最終ジャッジ日時 | 2026-01-08 23:21:53 |
| 合計ジャッジ時間 | 4,216 ms |
| ジャッジサーバーID (参考情報) | judge5 / judge4 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | AC * 1 |
| other | TLE * 1 -- * 19 |
ソースコード
MOD = 998244353


class ModInt:
    """Integer modulo ``MOD`` with value semantics.

    Every arithmetic operator returns a fresh instance and never mutates its
    operands.  There is deliberately no ``__iadd__``: ``a += b`` falls back
    to ``__add__`` and rebinds ``a``.  This keeps lists built with
    ``[ModInt(0)] * n`` (which share one object) safe to update element-wise.
    """

    __slots__ = ['x']

    def __init__(self, x):
        # Store the canonical representative in [0, MOD).
        self.x = x % MOD

    def __add__(self, other):
        return ModInt(self.x + other.x)

    def __sub__(self, other):
        return ModInt(self.x - other.x)

    def __mul__(self, other):
        return ModInt(self.x * other.x)

    def __repr__(self):
        # Bare number, so ``print(ModInt(...))`` emits just the value.
        return str(self.x)
def _add_polys(p, q):
    """Return the coefficient-wise sum of polynomials ``p`` and ``q``.

    The shorter operand is treated as zero-padded, so the result has
    ``max(len(p), len(q))`` terms.  (A plain ``zip`` would silently drop the
    tail of the longer operand.)  Tail coefficients are shared, not copied;
    callers here only read them.
    """
    if len(p) < len(q):
        p, q = q, p
    return [x + y for x, y in zip(p, q)] + p[len(q):]


def karatsuba(a, b):
    """Multiply two polynomials given as lists of ``ModInt`` coefficients.

    ``a[i]`` is the coefficient of ``x**i``.  Returns a new list of
    ``len(a) + len(b) - 1`` freshly created ``ModInt`` objects, or ``[]`` if
    either input is empty.  The inputs are not modified.

    Uses schoolbook multiplication when the shorter operand has at most 32
    terms, and Karatsuba splitting at ``max(len(a), len(b)) // 2`` otherwise.
    """
    n = len(a)
    m = len(b)
    if n == 0 or m == 0:
        return []
    if min(n, m) <= 32:
        # Accumulate raw residues as ints rather than allocating a ModInt per
        # term; reducing every step keeps each value below 2 * MOD**2.
        acc = [0] * (n + m - 1)
        bx = [v.x for v in b]
        for i, av in enumerate(a):
            ax = av.x
            for j, bv in enumerate(bx):
                acc[i + j] = (acc[i + j] + ax * bv) % MOD
        return [ModInt(v) for v in acc]
    k = max(n, m) // 2
    a0, a1 = a[:k], a[k:]
    b0, b1 = b[:k], b[k:]
    z0 = karatsuba(a0, b0)
    z2 = karatsuba(a1, b1)  # may be [] when one operand is much shorter
    # Halves can differ in length (or one can be empty); pad before summing.
    z1 = karatsuba(_add_polys(a0, a1), _add_polys(b0, b1))
    # Middle term z1 - z0 - z2.  Because the padded sums are at least as long
    # as each half, z1 is at least as long as z0 and z2.
    mid = [v.x for v in z1]
    for i, v in enumerate(z0):
        mid[i] -= v.x
    for i, v in enumerate(z2):
        mid[i] -= v.x
    # Combine into plain ints; one independent ModInt per output coefficient
    # avoids the aliasing of a ``[ModInt(0)] * size`` result list.
    acc = [0] * (n + m - 1)
    for i, v in enumerate(z0):
        acc[i] += v.x
    for i, v in enumerate(mid):
        acc[i + k] += v
    for i, v in enumerate(z2):
        acc[i + 2 * k] += v.x
    return [ModInt(v) for v in acc]
def calc(a, b, f, g):
    """Repeatedly multiply ``f`` by ``g``, keeping only degrees in a window.

    Polynomials are coefficient lists; ``f[j]`` is the coefficient of degree
    ``j``.  For each step ``i`` the current polynomial is multiplied by ``g``
    (via ``karatsuba``).  Every term whose absolute degree lies outside the
    half-open window ``[a[i], b[i])`` is then dropped.

    Before the main loop the windows are tightened forward and backward.  The
    propagation treats one step as raising the degree by between 0 and
    ``len(g) - 1``.  If any tightened window is empty, ``[]`` is returned.

    Returns the surviving coefficients, lowest kept degree first (just a copy
    of ``f`` when there are no windows).  ``a`` and ``b`` are not modified.
    """
    lo = list(a)
    hi = list(b)
    n = len(lo)
    rise = len(g) - 1  # largest degree increase a single step can cause
    # Forward pass: clamp each window to what the previous one can reach.
    for i in range(1, n):
        lo[i] = max(lo[i - 1], lo[i])
        hi[i] = min(hi[i - 1] + rise, hi[i])
    # Backward pass: clamp each window to what can still reach the next one.
    for i in range(n - 1, 0, -1):
        lo[i - 1] = max(lo[i - 1], lo[i] - rise)
        hi[i - 1] = min(hi[i - 1], hi[i])
    if any(x >= y for x, y in zip(lo, hi)):
        return []
    poly = f[:]
    offset = 0  # absolute degree represented by poly[0]
    for low, high in zip(lo, hi):
        poly = karatsuba(poly, g)
        # Windows are absolute degrees; translate them into list indices.
        # Slicing with the raw bounds is wrong once offset > 0.
        start = max(low - offset, 0)
        stop = max(high - offset, start)
        poly = poly[start:stop]
        offset += start
    return poly
# ---------- main ----------
def main():
    """Read S from stdin and print the computed count modulo ``MOD``.

    Every character of S other than 'B' is treated as an 'A'.  For the i-th
    'A' (0-based), let ``h`` = 1 + the number of 'B's before it.  Window
    ``i + h`` passed to ``calc`` then gets the upper bound ``h``.  The last
    window is pinned to exactly degree ``h_last - 1``.  The printed answer
    is the sum of the coefficients ``calc`` returns.  A string without any
    'A' prints 1.
    """
    s = input().strip()
    heights = []  # per 'A': 1 + number of 'B's seen before it
    h = 1
    for c in s:
        if c == 'B':
            h += 1
        else:
            heights.append(h)
    if not heights:
        print(1)
        return
    n = len(heights)
    len_total = n + heights[-1]
    lower = [0] * len_total
    upper = [10**9] * len_total  # 10**9 acts as "no upper bound"
    lower[-1] = heights[-1] - 1
    for i, hb in enumerate(heights):
        # heights is nondecreasing, so i + hb <= len_total - 1.
        upper[i + hb] = hb
    poly = calc(lower, upper, [ModInt(1)], [ModInt(1), ModInt(1)])
    print(ModInt(sum(c.x for c in poly)))


if __name__ == "__main__":
    main()