結果

問題 No.3439 [Cherry 8th Tune] どの頂点にいた頃に戻りたいのか?
コンテスト
ユーザー 👑 Kazun
提出日時 2025-11-16 21:46:20
言語 PyPy3
(7.3.17)
結果
RE  
実行時間 -
コード長 21,475 bytes
記録
記録タグの例:
初AC ショートコード 純ショートコード 純主流ショートコード 最速実行時間
コンパイル時間 329 ms
コンパイル使用メモリ 82,536 KB
実行使用メモリ 75,324 KB
最終ジャッジ日時 2026-01-23 21:04:41
合計ジャッジ時間 6,807 ms
ジャッジサーバーID
(参考情報)
judge3 / judge1
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample RE * 2
other RE * 37
権限があれば一括ダウンロードができます

ソースコード

diff #
raw source code

# Reference: https://qiita.com/tatyam/items/492c70ac4c955c055602
# ※ 計算量が O(sqrt(N)) per query なので, 過度な期待はしないこと.

from bisect import bisect_left, bisect_right
from typing import Generic, Iterable, Iterator, TypeVar
T = TypeVar('T')

class Sorted_Set(Generic[T]):
    BUCKET_RATIO=50
    REBUILD_RATIO=170

    def __init__(self, A: Iterable[T] = None):
        if A is None:
            A = []

        A = list(A)

        # Sorted ?
        if not all(A[i] < A[i+1] for i in range(len(A) - 1)):
            A = sorted(set(A))

        # Unique ?
        if not all(A[i] == A[i + 1] for i in range(len(A) - 1)):
            A, A_cand = [], A
            for a in A_cand:
                if (not A) or (A[-1] != a):
                    A.append(a)

        self.__build(A)

    def __build(self, A: list = None):
        if A is None:
            A = list(self)

        self._N = N = len(A)
        K = 0
        while self.BUCKET_RATIO * K * K < N:
            K += 1

        self._buckets: list[list[T]] = [A[N * i // K: N * (i + 1) // K] for i in range(K)]
        self._last : list[T] = [bucket[-1] for bucket in self._buckets]

    @property
    def N(self) -> int:
        return self._N

    def __iter__(self) -> Iterator[T]:
        for A in self._buckets:
            yield from A

    def __reversed__(self) -> Iterator[T]:
        for A in reversed(self._buckets):
            yield from reversed(A)

    def __len__(self) -> int:
        return self.N

    def __bool__(self) -> bool:
        return self.N > 0

    def is_empty(self) -> bool:
        """ 空集合かどうかを判断する.

        Returns:
            bool: 空集合ならば True
        """
        return self.N == 0

    def __str__(self) -> str:
        return f"{{{', '.join([str(x) for x in self])}}}"

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({list(self)})"

    def __find_bucket_index(self, x):
        if self._last[-1] < x:
            return len(self._last) -1

        return bisect_left(self._last, x)

    def _set_last(self, i: int, bucket: list[T]):
        self._last[i] = bucket[-1]

    def __contains__(self, x: T) -> bool:
        if self.is_empty():
            return False

        i = self.__find_bucket_index(x)
        A = self._buckets[i]
        j = bisect_left(A, x)
        return (j != len(A)) and (A[j] == x)

    def add(self, x: T) -> bool:
        """ 集合に要素 x を追加する.

        Args:
            x (T): 追加する要素

        Returns:
            bool: 追加による差分が発生すれば True
        """

        if self.is_empty():
            self._buckets=[[x]]
            self._last = [x]
            self._N += 1
            return True

        i = self.__find_bucket_index(x)
        A = self._buckets[i]
        j = bisect_left(A, x)

        if (j != len(A)) and (A[j] == x):
            return False # x が既に存在するので...

        A.insert(j, x)
        self._set_last(i, A)
        self._N += 1

        if len(A)>len(self._buckets)*self.REBUILD_RATIO:
            self.__build()

        return True

    def discard(self, x: T) -> bool:
        """ 集合から要素 x を削除する.

        Args:
            x (T): 削除する要素

        Returns:
            bool: 削除による差分が発生すれば True
        """

        if self.is_empty():
            return False

        i = self.__find_bucket_index(x)
        A = self._buckets[i]
        j = bisect_left(A, x)

        if not(j != len(A) and A[j] == x):
            return False # x が存在しないので...

        A.pop(j)
        self._N -= 1

        if A:
            self._set_last(i, A)
        else:
            self.__build()

        return True

    def remove(self, x: T):
        """ 集合から x を削除する.

        Args:
            x (T): 削除する要素

        Raises:
            KeyError: x が存在しないときに発生.
        """
        if not self.discard(x):
            raise KeyError(x)

    #=== get, pop

    def __getitem__(self, index):
        if index<0:
            index+=self.N
            if index<0:
                raise IndexError("index out of range")

        for A in self._buckets:
            if index<len(A):
                return A[index]
            index-=len(A)
        else:
            raise IndexError("index out of range")

    def get_min(self) -> T:
        """ 最小値を取得する.

        Raises:
            ValueError: 空集合であってはならない.

        Returns:
            T: 最小値
        """

        if self.is_empty():
            raise ValueError("This is empty set.")

        return self._buckets[0][0]

    def pop_min(self) -> T:
        """ 最小値を削除し, その最小値を返り値とする.

        Raises:
            ValueError: 空集合であってはならない.

        Returns:
            T: 最小値
        """

        if self.is_empty():
            raise ValueError("This is empty set.")

        A=self._buckets[0]
        value=A.pop(0)
        self._N -= 1

        if len(A)==0:
            self.__build()

        return value

    def get_max(self) -> T:
        """ 最大値を取得する.

        Raises:
            ValueError: 空集合であってはならない.

        Returns:
            T: 最大値
        """

        if self.is_empty():
            return ValueError("This is empty set.")

        return self._buckets[-1][-1]

    def pop_max(self) -> T:
        """ 最大値を削除し, その最大値を返り値とする.

        Raises:
            ValueError: 空集合であってはならない.

        Returns:
            T: 最大値
        """

        if self.is_empty():
            raise ValueError("This is empty set.")

        A=self._buckets[-1]
        value=A.pop(-1)
        self._N -= 1

        if A:
            self._set_last(len(self._buckets) - 1, A)
        else:
            self.__build()

        return value

    #=== k-th element
    def kth_min(self, k: int) -> T:
        """ k (0-indexed) 番目に小さい値を求める.

        Args:
            k (int): 要素番号

        Returns:
            T: k 番目に小さい値
        """

        if not(0 <= k < len(self)):
            raise IndexError

        return self[k]

    def kth_max(self, k: int) -> T:
        """ k (0-indexed) 番目に大きい値を求める.

        Args:
            k (int): 要素番号

        Returns:
            T: k 番目に大きい値
        """

        if not(0 <= k < len(self)):
            raise IndexError

        return self[len(self) - 1 - k]

    #=== previous, next

    def previous(self, value: T, equal: bool = False) -> T | None:
        """ value 未満の最大値を求める.

        Args:
            value (T): 閾値
            equal (bool, optional): True にすると, "未満" が "以下"になる. Defaults to False.

        Returns:
            T | None: value 未満の最大値 (存在しない場合は None)
        """

        if self.is_empty():
            return None

        if equal:
            for bucket in reversed(self._buckets):
                if bucket[0] <= value:
                    return bucket[bisect_right(bucket,value) - 1]
        else:
            for bucket in reversed(self._buckets):
                if bucket[0] <value:
                    return bucket[bisect_left(bucket, value) - 1]

    def next(self, value: T, equal: bool = False) -> T | None:
        """ value より大きい最小値を求める.

        Args:
            value (T): 閾値
            mode (bool, optional): True にすると, "より大きい" が "以上"になる. Defaults to False.

        Returns:
            T | None: value より大きい最小値 (存在しない場合は None)
        """

        if self.is_empty():
            return None

        if equal:
            for bucket in self._buckets:
                if bucket[-1] >= value:
                    return bucket[bisect_left(bucket, value)]
        else:
            for bucket in self._buckets:
                if bucket[-1] > value:
                    return bucket[bisect_right(bucket, value)]

    #=== count
    def less_count(self, value: T, equal: bool = False) -> int:
        """ value 未満の元の個数を求める.

        Args:
            value (T): 閾値
            equal (bool, optional): True にすると, "未満" が "以下" になる. Defaults to False.

        Returns:
            int: value 未満の元の個数
        """

        if self.is_empty():
            return 0

        count=0
        if equal:
            for A in self._buckets:
                if A[-1]>value:
                    return count+bisect_right(A, value)
                count+=len(A)
        else:
            for A in self._buckets:
                if A[-1]>=value:
                    return count+bisect_left(A, value)
                count+=len(A)
        return count

    def more_count(self, value: T, equal: bool = False) -> int:
        """ value より大きいの元の個数を求める.

        Args:
            value (T): 閾値
            equal (bool, optional): True にすると, "より大きい" が "以上" になる. Defaults to False.

        Returns:
            int: value より大きい元の個数
        """

        return self.N - self.less_count(value, not equal)

    #===
    def is_upper_bound(self, x: T, equal: bool = True) -> bool:
        """ x はこの集合の上界 (任意の元 a に対して, a <= x) か ?

        Args:
            x (T): 値
            equal (bool, optional): False にすると, 真の上界か? になる. Defaults to True.

        Returns:
            bool: 上界 ?
        """

        if self.is_empty():
            return True

        a=self._buckets[-1][-1]
        return (a<x) or (bool(equal) and a==x)

    def is_lower_bound(self, x: T, equal: bool = True) -> bool:
        """ x はこの集合の下界 (任意の元 a に対して, x <= a) か ?

        Args:
            x (T): 値
            equal (bool, optional): False にすると, 真の下界か? になる. Defaults to True.

        Returns:
            bool: 下界 ?
        """

        if self.is_empty():
            return True

        a=self._buckets[0][0]
        return (x<a) or (bool(equal) and a==x)


    #=== index
    def index(self, value: T) -> int:
        """ 要素 x の要素番号を求める.

        Args:
            value (T): 要素

        Raises:
            ValueError: 存在しない場合に発生

        Returns:
            int: 要素番号
        """

        index=0
        for A in self._buckets:
            if A[-1]>value:
                i=bisect_left(A, value)
                if A[i]==value:
                    return index+i
                else:
                    raise ValueError(f"{value} is not in Set")
            index+=len(A)
        raise ValueError(f"{value} is not in Set")

#==================================================
from typing import TypeVar, Generic, Callable, Generator

Group = TypeVar('Group')
class Binary_Indexed_Tree(Generic[Group]):
    def __init__(self, L: list[Group], op: Callable[[Group, Group], Group], zero: Group, neg: Callable[[Group], Group]):
        """ op を群 Group の演算として L から Binary Indexed Tree を生成する.

        Args:
            L (list[Group]): 初期状態
            op (Callable[[Group, Group], Group]): 群演算
            zero (Group): 群 Group における単位元 (任意の x in Group に対して, x + e = e + x = x となる e in Group)
            neg (Callable[[Group], Group]): x in Group における逆元 (x + y = y + x = e となる y) を求める関数
        """

        self.op=op
        self.zero=zero
        self.neg=neg
        self.sub: Callable[[Group, Group], Group] = lambda x, y: self.op(x, self.neg(y))
        self.N=N=len(L)
        self.log=N.bit_length()-1

        X=[zero]*(N+1)

        for i in range(N):
            p=i+1
            X[p]=op(X[p],L[i])
            q=p+(p&(-p))
            if q<=N:
                X[q]=op(X[q], X[p])
        self.data=X

    def get(self, k: int) -> Group:
        """ 第 k 項を求める.

        Args:
            k (int): 要素の位置

        Returns:
            Group: 第 k 項
        """
        return self.sum(k, k)

    def add(self, k: int, x: Group) -> None:
        """ 第 k 項に x を加え, 更新する.

        Args:
            k (int): 要素の位置
            x (Group): 加える Group の要素
        """

        data=self.data; op=self.op
        p=k+1
        while p<=self.N:
            data[p]=op(self.data[p], x)
            p+=p&(-p)

    def update(self, k: int, x: Group) -> None:
        """ 第 k 項を x に変えて更新する.

        Args:
            k (int): 要素の位置
            x (Group): 更新先の値
        """

        a=self.get(k)
        y = self.sub(x, a)

        self.add(k,y)

    def sum(self, l: int, r: int) -> Group:
        """ 第 l 項から第 r 項までの総和を求める (ただし, l != 0 のときは Group が群でなくてはならない).

        Args:
            l (int): 左端
            r (int): 右端

        Returns:
            Group: 総和
        """

        l=l+1 if 0<=l else 1
        r=r+1 if r<self.N else self.N

        if l>r:
            return self.zero
        elif l==1:
            return self.__section(r)
        else:
            return self.sub(self.__section(r), self.__section(l - 1))

    def __section(self, x: int) -> Group:
        """ B[0] + B[1] + ... + B[x] を求める.

        Args:
            x (int): 右端

        Returns:
            Group: 総和
        """

        data=self.data; op=self.op
        S=self.zero
        while x>0:
            S=op(data[x], S)
            x-=x&(-x)
        return S

    def all_sum(self) -> Group:
        """ B[0] + B[1] + ... + B[len(B) - 1] を求める.

        Returns:
            Group: 総和
        """
        return self.sum(0, self.N-1)

    def binary_search(self, cond: Callable[[Group], bool]) -> int:
        """ cond(B[0] + B[1] + ... + B[k]) が True になる最小の k を止める.

        ※ Group は順序群である必要がある.
        ※ cond(zero) = True のとき, 返り値は -1 とする.
        ※ cond(B[0] + ... + B[k]) なる k が (0 <= k < N に) 存在しない場合, 返り値は N とする.

        Args:
            cond (Callable[[Group], bool]): 単調増加な条件

        Returns:
            int: cond(B[0] + B[1] + ... + B[k]) が True になる最小の k
        """

        if cond(self.zero):
            return -1

        j=0
        t=1<<self.log
        data=self.data; op=self.op
        alpha=self.zero

        while t>0:
            if j+t<=self.N:
                beta=op(alpha, data[j+t])
                if not cond(beta):
                    alpha=beta
                    j+=t
            t>>=1

        return j

    def __getitem__(self, index) -> Group:
        if isinstance(index, int):
            return self.get(index)
        else:
            return [self.get(t) for t in index]

    def __setitem__(self, index: int, val: Group):
        self.update(index, val)

    def __iter__(self):
        for k in range(self.N):
            yield self.sum(k, k)

class Range_Binary_Indexed_Tree():
    def __init__(self, L, op, zero, neg, mul):

        self.op = op
        self.zero = zero
        self.neg = neg
        self.mul = mul
        self.N = len(L)

        self.bit0 = Binary_Indexed_Tree(L, op, zero, neg)
        self.bit1 = Binary_Indexed_Tree([zero]*len(L), op, zero, neg)

    def get(self, k):
        """ 第 k 要素の値を出力する.

        k    : 数列の要素
        """
        return self.sum(k, k)

    def add(self, k, x):
        """ 第 k 要素に x を加え, 更新を行う.
        k    : 数列の要素
        x    : 加える値
        index: 先頭の要素の番号
        """
        self.bit0.add(k, x)

    def update(self, k, x):
        self.bit0.add(k, self.op(self.neg(self.get(k)), x))

    def add_range(self, l, r, x):
        """ 第 l 要素から第 r 要素までに一様に x を加える.

        Args:
            l (int): 左端
            r (int): 右端
            x:
        """

        self.bit0.add(l, self.neg(self.mul(l, x)))
        self.bit1.add(l, x)
        if r < self.N - 1:
            self.bit0.add(r + 1, self.mul(r + 1, x))
            self.bit1.add(r + 1, self.neg(x))

    def sum(self, l, r):
        """ 第 l 要素から第 r 要素までの総和を求める.
        ※ l != index ならば, 群でなくてはならない.
        l : 始まり
        r   : 終わり
        index: 先頭の要素の番号
        """
        if l > 0:
            return self.op(self.__section(r), self.neg(self.__section(l - 1)))
        else:
            return self.__section(r)

    def __section(self, k):
        return self.op(self.bit0.sum(0, k), self.mul(k + 1, self.bit1.sum(0, k)))

    def all_sum(self):
        return self.sum(0, self.N - 1)

    def __getitem__(self, index):
        if isinstance(index, int):
            return self.get(index)
        else:
            return [self.get(t) for t in index]

    def __setitem__(self, index, value):
        self.update(index, value)

    def __iter__(self):
        for ind in range(self.N):
            yield self.sum(ind, ind)

#==================================================
def fetch_invs(N: int) -> list[int]:
    invs = [0] * (N + 1)
    invs[1] = 1
    for i in range(2, N + 1):
        q, r = divmod(Mod, i)
        invs[i] = -q * invs[r] % Mod
    return invs

#==================================================
def testcase():
    from operator import neg

    N, Q = map(int, input().split())
    P = ["*"] + list(input().strip())
    S = list(map(int, input().split()))

    invs = fetch_invs(2 * N + 10)

    P[N] = "B"
    E = Sorted_Set[int]([i for i in range(N + 1) if (i == 0) or P[i] == "B"])

    add = lambda x, y: (x + y) % Mod
    mul = lambda x, y: (x * y) % Mod

    A = Range_Binary_Indexed_Tree([0] * (N + 1), add, 0, neg, mul)
    B = Range_Binary_Indexed_Tree([0] * (N + 1), add, 0, neg, mul)
    S = Binary_Indexed_Tree[int](S, add, 0, neg)

    for i in range(1, N + 1):
        A.update(i, i - E.previous(i))

    g = 0
    for i in range(1, N + 1):
        if P[i] == "G":
            g += 1
            continue

        h = (g + 1) * (g + 2) * invs[2] * invs[i + 1] % Mod
        B.add_range(i + 1, N, h)
        g = 0

    # 初期状態での Go on Back!! の解を求める.
    ans = [0] * (Q + 1)
    ans[0] = sum((A[i] + B[i]) * S[i] % Mod for i in range(N + 1)) % Mod

    for q in range(1, Q + 1):
        ans[q] = ans[q - 1]
        t, *value = map(int, input().split())
        if t == 1:
            i, = value
            delta = 0
            if i == N:
                continue

            l = E.previous(i)
            r = E.next(i)

            if P[i] == "G":
                # alpha に関する差分の計算
                delta -= (i - l) * S.sum(i + 1, r) % Mod
                A.add_range(i + 1, r, -(i - l))

                # beta に関する差分の計算
                g = r - l - 1
                gl = i - l - 1
                gr = g - (gl + 1)

                h = (g + 1) * (g + 2) % Mod * invs[2 * (r + 1)] % Mod
                hl = (gl + 1) * (gl + 2) % Mod * invs[2 * (i + 1)] % Mod
                hr = (gr + 1) * (gr + 2) % Mod * invs[2 * (r + 1)] % Mod

                delta -= h * S.sum(r + 1, N) % Mod
                delta += hl * S.sum(i + 1, N) % Mod
                delta += hr * S.sum(r + 1, N) % Mod

                B.add_range(r + 1, N, -h)
                B.add_range(i + 1, N, hl)
                B.add_range(r + 1, N, hr)

                P[i] = "B"
                E.add(i)
            else:
                # alpha に関する差分の計算
                delta += (i - l) * S.sum(i + 1, r) % Mod
                A.add_range(i + 1, r, i - l)

                # beta に関する差分の計算
                gl = i - l - 1
                gr = r - i - 1
                g = gl + gr + 1

                h = (g + 1) * (g + 2)  % Mod * invs[2 * (r + 1)] % Mod
                hl = (gl + 1) * (gl + 2) % Mod * invs[2 * (i + 1)] % Mod
                hr = (gr + 1) * (gr + 2) % Mod * invs[2 * (r + 1)] % Mod

                delta += h * S.sum(r + 1, N) % Mod
                delta -= hl * S.sum(i + 1, N) % Mod
                delta -= hr * S.sum(r + 1, N) % Mod

                B.add_range(r + 1, N, h)
                B.add_range(i + 1, N, -hl)
                B.add_range(r + 1, N, -hr)

                P[i] = "G"
                E.remove(i)

            ans[q] += delta
        elif t == 2:
            i, b = value
            ans[q] += (A[i] + B[i]) * (b - S[i])
            S.update(i, b)

        ans[q] %= Mod

    return ans[1:]

#==================================================
def solve():
    T = int(input())
    to_string = lambda ans: " ".join(map(str, ans))
    write("\n".join(map(to_string, [testcase() for _ in range(T)])))

#==================================================
import sys
input = sys.stdin.readline
write = sys.stdout.write

Mod = 998244353

solve()
0