結果
| 問題 | No.3439 [Cherry 8th Tune] どの頂点にいた頃に戻りたいのか? |
| コンテスト | |
| ユーザー |
👑 Kazun
|
| 提出日時 | 2025-11-16 21:46:20 |
| 言語 | PyPy3 (7.3.17) |
| 結果 |
RE
|
| 実行時間 | - |
| コード長 | 21,475 bytes |
| 記録 | |
| コンパイル時間 | 329 ms |
| コンパイル使用メモリ | 82,536 KB |
| 実行使用メモリ | 75,324 KB |
| 最終ジャッジ日時 | 2026-01-23 21:04:41 |
| 合計ジャッジ時間 | 6,807 ms |
|
ジャッジサーバーID (参考情報) |
judge3 / judge1 |
(要ログイン)
| ファイルパターン | 結果 |
|---|---|
| sample | RE * 2 |
| other | RE * 37 |
ソースコード
# Reference: https://qiita.com/tatyam/items/492c70ac4c955c055602
# ※ 計算量が O(sqrt(N)) per query なので, 過度な期待はしないこと.
from bisect import bisect_left, bisect_right
from typing import Generic, Iterable, Iterator, TypeVar
T = TypeVar('T')
class Sorted_Set(Generic[T]):
BUCKET_RATIO=50
REBUILD_RATIO=170
def __init__(self, A: Iterable[T] = None):
if A is None:
A = []
A = list(A)
# Sorted ?
if not all(A[i] < A[i+1] for i in range(len(A) - 1)):
A = sorted(set(A))
# Unique ?
if not all(A[i] == A[i + 1] for i in range(len(A) - 1)):
A, A_cand = [], A
for a in A_cand:
if (not A) or (A[-1] != a):
A.append(a)
self.__build(A)
def __build(self, A: list = None):
if A is None:
A = list(self)
self._N = N = len(A)
K = 0
while self.BUCKET_RATIO * K * K < N:
K += 1
self._buckets: list[list[T]] = [A[N * i // K: N * (i + 1) // K] for i in range(K)]
self._last : list[T] = [bucket[-1] for bucket in self._buckets]
@property
def N(self) -> int:
return self._N
def __iter__(self) -> Iterator[T]:
for A in self._buckets:
yield from A
def __reversed__(self) -> Iterator[T]:
for A in reversed(self._buckets):
yield from reversed(A)
def __len__(self) -> int:
return self.N
def __bool__(self) -> bool:
return self.N > 0
def is_empty(self) -> bool:
""" 空集合かどうかを判断する.
Returns:
bool: 空集合ならば True
"""
return self.N == 0
def __str__(self) -> str:
return f"{{{', '.join([str(x) for x in self])}}}"
def __repr__(self) -> str:
return f"{self.__class__.__name__}({list(self)})"
def __find_bucket_index(self, x):
if self._last[-1] < x:
return len(self._last) -1
return bisect_left(self._last, x)
def _set_last(self, i: int, bucket: list[T]):
self._last[i] = bucket[-1]
def __contains__(self, x: T) -> bool:
if self.is_empty():
return False
i = self.__find_bucket_index(x)
A = self._buckets[i]
j = bisect_left(A, x)
return (j != len(A)) and (A[j] == x)
def add(self, x: T) -> bool:
""" 集合に要素 x を追加する.
Args:
x (T): 追加する要素
Returns:
bool: 追加による差分が発生すれば True
"""
if self.is_empty():
self._buckets=[[x]]
self._last = [x]
self._N += 1
return True
i = self.__find_bucket_index(x)
A = self._buckets[i]
j = bisect_left(A, x)
if (j != len(A)) and (A[j] == x):
return False # x が既に存在するので...
A.insert(j, x)
self._set_last(i, A)
self._N += 1
if len(A)>len(self._buckets)*self.REBUILD_RATIO:
self.__build()
return True
def discard(self, x: T) -> bool:
""" 集合から要素 x を削除する.
Args:
x (T): 削除する要素
Returns:
bool: 削除による差分が発生すれば True
"""
if self.is_empty():
return False
i = self.__find_bucket_index(x)
A = self._buckets[i]
j = bisect_left(A, x)
if not(j != len(A) and A[j] == x):
return False # x が存在しないので...
A.pop(j)
self._N -= 1
if A:
self._set_last(i, A)
else:
self.__build()
return True
def remove(self, x: T):
""" 集合から x を削除する.
Args:
x (T): 削除する要素
Raises:
KeyError: x が存在しないときに発生.
"""
if not self.discard(x):
raise KeyError(x)
#=== get, pop
def __getitem__(self, index):
if index<0:
index+=self.N
if index<0:
raise IndexError("index out of range")
for A in self._buckets:
if index<len(A):
return A[index]
index-=len(A)
else:
raise IndexError("index out of range")
def get_min(self) -> T:
""" 最小値を取得する.
Raises:
ValueError: 空集合であってはならない.
Returns:
T: 最小値
"""
if self.is_empty():
raise ValueError("This is empty set.")
return self._buckets[0][0]
def pop_min(self) -> T:
""" 最小値を削除し, その最小値を返り値とする.
Raises:
ValueError: 空集合であってはならない.
Returns:
T: 最小値
"""
if self.is_empty():
raise ValueError("This is empty set.")
A=self._buckets[0]
value=A.pop(0)
self._N -= 1
if len(A)==0:
self.__build()
return value
def get_max(self) -> T:
""" 最大値を取得する.
Raises:
ValueError: 空集合であってはならない.
Returns:
T: 最大値
"""
if self.is_empty():
return ValueError("This is empty set.")
return self._buckets[-1][-1]
def pop_max(self) -> T:
""" 最大値を削除し, その最大値を返り値とする.
Raises:
ValueError: 空集合であってはならない.
Returns:
T: 最大値
"""
if self.is_empty():
raise ValueError("This is empty set.")
A=self._buckets[-1]
value=A.pop(-1)
self._N -= 1
if A:
self._set_last(len(self._buckets) - 1, A)
else:
self.__build()
return value
#=== k-th element
def kth_min(self, k: int) -> T:
""" k (0-indexed) 番目に小さい値を求める.
Args:
k (int): 要素番号
Returns:
T: k 番目に小さい値
"""
if not(0 <= k < len(self)):
raise IndexError
return self[k]
def kth_max(self, k: int) -> T:
""" k (0-indexed) 番目に大きい値を求める.
Args:
k (int): 要素番号
Returns:
T: k 番目に大きい値
"""
if not(0 <= k < len(self)):
raise IndexError
return self[len(self) - 1 - k]
#=== previous, next
def previous(self, value: T, equal: bool = False) -> T | None:
""" value 未満の最大値を求める.
Args:
value (T): 閾値
equal (bool, optional): True にすると, "未満" が "以下"になる. Defaults to False.
Returns:
T | None: value 未満の最大値 (存在しない場合は None)
"""
if self.is_empty():
return None
if equal:
for bucket in reversed(self._buckets):
if bucket[0] <= value:
return bucket[bisect_right(bucket,value) - 1]
else:
for bucket in reversed(self._buckets):
if bucket[0] <value:
return bucket[bisect_left(bucket, value) - 1]
def next(self, value: T, equal: bool = False) -> T | None:
""" value より大きい最小値を求める.
Args:
value (T): 閾値
mode (bool, optional): True にすると, "より大きい" が "以上"になる. Defaults to False.
Returns:
T | None: value より大きい最小値 (存在しない場合は None)
"""
if self.is_empty():
return None
if equal:
for bucket in self._buckets:
if bucket[-1] >= value:
return bucket[bisect_left(bucket, value)]
else:
for bucket in self._buckets:
if bucket[-1] > value:
return bucket[bisect_right(bucket, value)]
#=== count
def less_count(self, value: T, equal: bool = False) -> int:
""" value 未満の元の個数を求める.
Args:
value (T): 閾値
equal (bool, optional): True にすると, "未満" が "以下" になる. Defaults to False.
Returns:
int: value 未満の元の個数
"""
if self.is_empty():
return 0
count=0
if equal:
for A in self._buckets:
if A[-1]>value:
return count+bisect_right(A, value)
count+=len(A)
else:
for A in self._buckets:
if A[-1]>=value:
return count+bisect_left(A, value)
count+=len(A)
return count
def more_count(self, value: T, equal: bool = False) -> int:
""" value より大きいの元の個数を求める.
Args:
value (T): 閾値
equal (bool, optional): True にすると, "より大きい" が "以上" になる. Defaults to False.
Returns:
int: value より大きい元の個数
"""
return self.N - self.less_count(value, not equal)
#===
def is_upper_bound(self, x: T, equal: bool = True) -> bool:
""" x はこの集合の上界 (任意の元 a に対して, a <= x) か ?
Args:
x (T): 値
equal (bool, optional): False にすると, 真の上界か? になる. Defaults to True.
Returns:
bool: 上界 ?
"""
if self.is_empty():
return True
a=self._buckets[-1][-1]
return (a<x) or (bool(equal) and a==x)
def is_lower_bound(self, x: T, equal: bool = True) -> bool:
""" x はこの集合の下界 (任意の元 a に対して, x <= a) か ?
Args:
x (T): 値
equal (bool, optional): False にすると, 真の下界か? になる. Defaults to True.
Returns:
bool: 下界 ?
"""
if self.is_empty():
return True
a=self._buckets[0][0]
return (x<a) or (bool(equal) and a==x)
#=== index
def index(self, value: T) -> int:
""" 要素 x の要素番号を求める.
Args:
value (T): 要素
Raises:
ValueError: 存在しない場合に発生
Returns:
int: 要素番号
"""
index=0
for A in self._buckets:
if A[-1]>value:
i=bisect_left(A, value)
if A[i]==value:
return index+i
else:
raise ValueError(f"{value} is not in Set")
index+=len(A)
raise ValueError(f"{value} is not in Set")
#==================================================
from typing import TypeVar, Generic, Callable, Generator
Group = TypeVar('Group')
class Binary_Indexed_Tree(Generic[Group]):
def __init__(self, L: list[Group], op: Callable[[Group, Group], Group], zero: Group, neg: Callable[[Group], Group]):
""" op を群 Group の演算として L から Binary Indexed Tree を生成する.
Args:
L (list[Group]): 初期状態
op (Callable[[Group, Group], Group]): 群演算
zero (Group): 群 Group における単位元 (任意の x in Group に対して, x + e = e + x = x となる e in Group)
neg (Callable[[Group], Group]): x in Group における逆元 (x + y = y + x = e となる y) を求める関数
"""
self.op=op
self.zero=zero
self.neg=neg
self.sub: Callable[[Group, Group], Group] = lambda x, y: self.op(x, self.neg(y))
self.N=N=len(L)
self.log=N.bit_length()-1
X=[zero]*(N+1)
for i in range(N):
p=i+1
X[p]=op(X[p],L[i])
q=p+(p&(-p))
if q<=N:
X[q]=op(X[q], X[p])
self.data=X
def get(self, k: int) -> Group:
""" 第 k 項を求める.
Args:
k (int): 要素の位置
Returns:
Group: 第 k 項
"""
return self.sum(k, k)
def add(self, k: int, x: Group) -> None:
""" 第 k 項に x を加え, 更新する.
Args:
k (int): 要素の位置
x (Group): 加える Group の要素
"""
data=self.data; op=self.op
p=k+1
while p<=self.N:
data[p]=op(self.data[p], x)
p+=p&(-p)
def update(self, k: int, x: Group) -> None:
""" 第 k 項を x に変えて更新する.
Args:
k (int): 要素の位置
x (Group): 更新先の値
"""
a=self.get(k)
y = self.sub(x, a)
self.add(k,y)
def sum(self, l: int, r: int) -> Group:
""" 第 l 項から第 r 項までの総和を求める (ただし, l != 0 のときは Group が群でなくてはならない).
Args:
l (int): 左端
r (int): 右端
Returns:
Group: 総和
"""
l=l+1 if 0<=l else 1
r=r+1 if r<self.N else self.N
if l>r:
return self.zero
elif l==1:
return self.__section(r)
else:
return self.sub(self.__section(r), self.__section(l - 1))
def __section(self, x: int) -> Group:
""" B[0] + B[1] + ... + B[x] を求める.
Args:
x (int): 右端
Returns:
Group: 総和
"""
data=self.data; op=self.op
S=self.zero
while x>0:
S=op(data[x], S)
x-=x&(-x)
return S
def all_sum(self) -> Group:
""" B[0] + B[1] + ... + B[len(B) - 1] を求める.
Returns:
Group: 総和
"""
return self.sum(0, self.N-1)
def binary_search(self, cond: Callable[[Group], bool]) -> int:
""" cond(B[0] + B[1] + ... + B[k]) が True になる最小の k を止める.
※ Group は順序群である必要がある.
※ cond(zero) = True のとき, 返り値は -1 とする.
※ cond(B[0] + ... + B[k]) なる k が (0 <= k < N に) 存在しない場合, 返り値は N とする.
Args:
cond (Callable[[Group], bool]): 単調増加な条件
Returns:
int: cond(B[0] + B[1] + ... + B[k]) が True になる最小の k
"""
if cond(self.zero):
return -1
j=0
t=1<<self.log
data=self.data; op=self.op
alpha=self.zero
while t>0:
if j+t<=self.N:
beta=op(alpha, data[j+t])
if not cond(beta):
alpha=beta
j+=t
t>>=1
return j
def __getitem__(self, index) -> Group:
if isinstance(index, int):
return self.get(index)
else:
return [self.get(t) for t in index]
def __setitem__(self, index: int, val: Group):
self.update(index, val)
def __iter__(self):
for k in range(self.N):
yield self.sum(k, k)
class Range_Binary_Indexed_Tree():
def __init__(self, L, op, zero, neg, mul):
self.op = op
self.zero = zero
self.neg = neg
self.mul = mul
self.N = len(L)
self.bit0 = Binary_Indexed_Tree(L, op, zero, neg)
self.bit1 = Binary_Indexed_Tree([zero]*len(L), op, zero, neg)
def get(self, k):
""" 第 k 要素の値を出力する.
k : 数列の要素
"""
return self.sum(k, k)
def add(self, k, x):
""" 第 k 要素に x を加え, 更新を行う.
k : 数列の要素
x : 加える値
index: 先頭の要素の番号
"""
self.bit0.add(k, x)
def update(self, k, x):
self.bit0.add(k, self.op(self.neg(self.get(k)), x))
def add_range(self, l, r, x):
""" 第 l 要素から第 r 要素までに一様に x を加える.
Args:
l (int): 左端
r (int): 右端
x:
"""
self.bit0.add(l, self.neg(self.mul(l, x)))
self.bit1.add(l, x)
if r < self.N - 1:
self.bit0.add(r + 1, self.mul(r + 1, x))
self.bit1.add(r + 1, self.neg(x))
def sum(self, l, r):
""" 第 l 要素から第 r 要素までの総和を求める.
※ l != index ならば, 群でなくてはならない.
l : 始まり
r : 終わり
index: 先頭の要素の番号
"""
if l > 0:
return self.op(self.__section(r), self.neg(self.__section(l - 1)))
else:
return self.__section(r)
def __section(self, k):
return self.op(self.bit0.sum(0, k), self.mul(k + 1, self.bit1.sum(0, k)))
def all_sum(self):
return self.sum(0, self.N - 1)
def __getitem__(self, index):
if isinstance(index, int):
return self.get(index)
else:
return [self.get(t) for t in index]
def __setitem__(self, index, value):
self.update(index, value)
def __iter__(self):
for ind in range(self.N):
yield self.sum(ind, ind)
#==================================================
def fetch_invs(N: int) -> list[int]:
invs = [0] * (N + 1)
invs[1] = 1
for i in range(2, N + 1):
q, r = divmod(Mod, i)
invs[i] = -q * invs[r] % Mod
return invs
#==================================================
def testcase():
from operator import neg
N, Q = map(int, input().split())
P = ["*"] + list(input().strip())
S = list(map(int, input().split()))
invs = fetch_invs(2 * N + 10)
P[N] = "B"
E = Sorted_Set[int]([i for i in range(N + 1) if (i == 0) or P[i] == "B"])
add = lambda x, y: (x + y) % Mod
mul = lambda x, y: (x * y) % Mod
A = Range_Binary_Indexed_Tree([0] * (N + 1), add, 0, neg, mul)
B = Range_Binary_Indexed_Tree([0] * (N + 1), add, 0, neg, mul)
S = Binary_Indexed_Tree[int](S, add, 0, neg)
for i in range(1, N + 1):
A.update(i, i - E.previous(i))
g = 0
for i in range(1, N + 1):
if P[i] == "G":
g += 1
continue
h = (g + 1) * (g + 2) * invs[2] * invs[i + 1] % Mod
B.add_range(i + 1, N, h)
g = 0
# 初期状態での Go on Back!! の解を求める.
ans = [0] * (Q + 1)
ans[0] = sum((A[i] + B[i]) * S[i] % Mod for i in range(N + 1)) % Mod
for q in range(1, Q + 1):
ans[q] = ans[q - 1]
t, *value = map(int, input().split())
if t == 1:
i, = value
delta = 0
if i == N:
continue
l = E.previous(i)
r = E.next(i)
if P[i] == "G":
# alpha に関する差分の計算
delta -= (i - l) * S.sum(i + 1, r) % Mod
A.add_range(i + 1, r, -(i - l))
# beta に関する差分の計算
g = r - l - 1
gl = i - l - 1
gr = g - (gl + 1)
h = (g + 1) * (g + 2) % Mod * invs[2 * (r + 1)] % Mod
hl = (gl + 1) * (gl + 2) % Mod * invs[2 * (i + 1)] % Mod
hr = (gr + 1) * (gr + 2) % Mod * invs[2 * (r + 1)] % Mod
delta -= h * S.sum(r + 1, N) % Mod
delta += hl * S.sum(i + 1, N) % Mod
delta += hr * S.sum(r + 1, N) % Mod
B.add_range(r + 1, N, -h)
B.add_range(i + 1, N, hl)
B.add_range(r + 1, N, hr)
P[i] = "B"
E.add(i)
else:
# alpha に関する差分の計算
delta += (i - l) * S.sum(i + 1, r) % Mod
A.add_range(i + 1, r, i - l)
# beta に関する差分の計算
gl = i - l - 1
gr = r - i - 1
g = gl + gr + 1
h = (g + 1) * (g + 2) % Mod * invs[2 * (r + 1)] % Mod
hl = (gl + 1) * (gl + 2) % Mod * invs[2 * (i + 1)] % Mod
hr = (gr + 1) * (gr + 2) % Mod * invs[2 * (r + 1)] % Mod
delta += h * S.sum(r + 1, N) % Mod
delta -= hl * S.sum(i + 1, N) % Mod
delta -= hr * S.sum(r + 1, N) % Mod
B.add_range(r + 1, N, h)
B.add_range(i + 1, N, -hl)
B.add_range(r + 1, N, -hr)
P[i] = "G"
E.remove(i)
ans[q] += delta
elif t == 2:
i, b = value
ans[q] += (A[i] + B[i]) * (b - S[i])
S.update(i, b)
ans[q] %= Mod
return ans[1:]
#==================================================
def solve():
T = int(input())
to_string = lambda ans: " ".join(map(str, ans))
write("\n".join(map(to_string, [testcase() for _ in range(T)])))
#==================================================
import sys
input = sys.stdin.readline
write = sys.stdout.write
Mod = 998244353
solve()
Kazun