結果

問題 No.3239 Omnibus
ユーザー Nauclhlt🪷
提出日時 2025-08-14 22:32:59
言語 PyPy3 (7.3.15)
結果
WA  
実行時間 -
コード長 5,244 bytes
コンパイル時間 434 ms
コンパイル使用メモリ 81,796 KB
実行使用メモリ 191,092 KB
最終ジャッジ日時 2025-08-14 22:33:08
合計ジャッジ時間 8,581 ms
ジャッジサーバーID (参考情報): judge4 / judge3
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 1
other AC * 3 WA * 1 TLE * 1 -- * 27
権限があれば一括ダウンロードができます

ソースコード

diff #

class AVLSet:
    """AVL tree of numbers augmented with subtree sizes and subtree sums.

    ``add``/``remove`` keep the tree height-balanced via rotations, so
    ``lower_bound`` (rank of a value) and ``prefix_sum`` (sum of the r
    smallest values) each walk a single root-to-leaf path.
    Equal values may be inserted repeatedly; at insertion time they
    descend to the right, and the in-order sequence stays sorted.
    """

    class Node:
        __slots__ = ['val', 'sum', 'left', 'right', 'bias', 'height', 'size']

        def __init__(self, val):
            self.val = val
            self.sum = val      # sum of every value in this subtree
            self.left = None
            self.right = None
            self.bias = 0       # height(left) - height(right)
            self.height = 1
            self.size = 1       # number of nodes in this subtree

    def __init__(self):
        self.root = None

    # ===== Public operations =====
    def add(self, val):
        """Insert ``val`` (duplicates allowed)."""
        self.root = self._add(self.root, val)

    def remove(self, val):
        """Remove one occurrence of ``val``; do nothing if it is absent."""
        self.root = self._remove(self.root, val)

    def lower_bound(self, val):
        """Return how many stored values are strictly less than ``val``."""
        node = self.root
        idx = 0
        while node:
            if val <= node.val:
                node = node.left
            else:
                # node and its whole left subtree are < val.
                idx += (node.left.size if node.left else 0) + 1
                node = node.right
        return idx

    def prefix_sum(self, r):
        """Return the sum of the ``r`` smallest stored values.

        Returns 0 when ``r <= 0`` and the total sum when ``r`` is at least
        the number of stored values.
        """
        return self._prefix_sum(self.root, r)

    # ===== Internal helpers =====
    def _update(self, node):
        """Recompute height, bias, size and sum of ``node`` from its children."""
        if node is None:
            return
        lh = node.left.height if node.left else 0
        rh = node.right.height if node.right else 0
        node.height = max(lh, rh) + 1
        node.bias = lh - rh
        ls = node.left.size if node.left else 0
        rs = node.right.size if node.right else 0
        node.size = ls + rs + 1
        lsum = node.left.sum if node.left else 0
        rsum = node.right.sum if node.right else 0
        node.sum = lsum + rsum + node.val

    def _rotate_left(self, node):
        """Rotate ``node`` left; return the new subtree root."""
        r = node.right
        node.right = r.left
        r.left = node
        # Update the lower node first: the new root reads its fields.
        self._update(node)
        self._update(r)
        return r

    def _rotate_right(self, node):
        """Rotate ``node`` right; return the new subtree root."""
        l = node.left
        node.left = l.right
        l.right = node
        self._update(node)
        self._update(l)
        return l

    def _balance(self, node):
        """Restore the AVL invariant at ``node`` (|bias| <= 1); return new root."""
        if node.bias < -1:
            # Right-heavy; a right-left shape needs a preliminary rotation.
            if node.right and node.right.bias > 0:
                node.right = self._rotate_right(node.right)
            return self._rotate_left(node)
        elif node.bias > 1:
            # Left-heavy; a left-right shape needs a preliminary rotation.
            if node.left and node.left.bias < 0:
                node.left = self._rotate_left(node.left)
            return self._rotate_right(node)
        return node

    def _add(self, node, val):
        """Insert ``val`` under ``node``; return the rebalanced subtree root."""
        if node is None:
            return AVLSet.Node(val)
        if val < node.val:
            node.left = self._add(node.left, val)
        else:
            node.right = self._add(node.right, val)
        self._update(node)
        return self._balance(node)

    def _remove(self, node, val):
        """Remove one ``val`` under ``node``; return the rebalanced subtree root."""
        if node is None:
            return None
        if val < node.val:
            node.left = self._remove(node.left, val)
        elif val > node.val:
            node.right = self._remove(node.right, val)
        else:
            if node.left and node.right:
                # Replace with the in-order predecessor, then delete that.
                pred = node.left
                while pred.right:
                    pred = pred.right
                node.val = pred.val
                node.left = self._remove(node.left, pred.val)
            elif node.left:
                return node.left
            elif node.right:
                return node.right
            else:
                return None
        self._update(node)
        return self._balance(node)

    def _prefix_sum(self, node, r):
        """Sum of the ``r`` smallest values in the subtree rooted at ``node``.

        Iterative descent (one step per level) instead of recursion, which
        is noticeably cheaper on the per-query hot path.
        """
        total = 0
        while node is not None and r > 0:
            left = node.left
            left_size = left.size if left else 0
            if r <= left_size:
                node = left
                continue
            # Take the whole left subtree plus this node, then continue right.
            total += (left.sum if left else 0) + node.val
            r -= left_size + 1
            node = node.right
        return total

# ===== Helpers and solve function =====
def encode(s, i):
    """Pack the three lowercase letters s[i], s[i+1], s[i+2] into one int.

    Each letter is a base-26 digit ('a' -> 0 ... 'z' -> 25) with s[i] the
    most significant, so the result lies in [0, 26**3). ``s`` may be a
    str or a list of single-character strings.
    """
    offset = ord('a')
    packed = 0
    for step in range(3):
        packed = packed * 26 + (ord(s[i + step]) - offset)
    return packed

def solve():
    """Read all input from stdin, process every query, print the answers.

    Input, as parsed below: ``N Q``, the string ``S``, then ``Q`` queries:

    * ``1 k x`` -- replace the k-th character (1-indexed) of S with ``x``.
    * ``2 l r a`` -- print the sum of ``i - l + 1`` over every 1-indexed
      start ``i`` of an occurrence of the 3-letter string ``a`` with
      ``l <= i <= r - 2`` (the occurrence fits entirely inside S[l..r]).

    ``indices[code]`` is an AVLSet of the 1-indexed start positions where
    the 3-letter ``code`` currently occurs. A type-2 answer is then
    (sum of matching positions) - count * (l - 1).

    NOTE(review): the ``r - 2`` upper bound assumes an occurrence must lie
    completely inside S[l..r]; confirm against the problem statement.
    """
    import sys

    # Bulk read and a single write at the end: per-line readline()/print()
    # is a large constant-factor cost under PyPy.
    tokens = sys.stdin.buffer.read().split()
    N = int(tokens[0])
    Q = int(tokens[1])
    S = tokens[2].decode()
    state = list(S)
    indices = [None] * (26 * 26 * 26)

    def insert(code, pos):
        # Create the per-code set lazily on first use.
        tree = indices[code]
        if tree is None:
            tree = indices[code] = AVLSet()
        tree.add(pos)

    for i in range(N - 2):
        insert(encode(S, i), i + 1)

    out = []
    ptr = 3
    for _ in range(Q):
        if int(tokens[ptr]) == 1:
            k = int(tokens[ptr + 1]) - 1
            x = tokens[ptr + 2].decode()
            ptr += 3
            if state[k] == x:
                continue  # same character: no window changes
            # Windows starting at k-2, k-1, k contain position k; clamp
            # them to the valid 0-indexed starts 0..N-3.
            lo = max(k - 2, 0)
            hi = min(k, N - 3)
            for j in range(lo, hi + 1):
                indices[encode(state, j)].remove(j + 1)
            state[k] = x
            for j in range(lo, hi + 1):
                insert(encode(state, j), j + 1)
        else:
            l = int(tokens[ptr + 1])
            r = int(tokens[ptr + 2])
            a = tokens[ptr + 3].decode()
            ptr += 4
            tree = indices[encode(a, 0)]
            # A 3-letter occurrence needs at least 3 characters in [l, r].
            if tree is None or r - l < 2:
                out.append(0)
                continue
            # Valid starts are l <= i <= r - 2: a window starting at r - 1
            # would need S[r + 1], which lies outside the query range.
            hi_cnt = tree.lower_bound(r - 1)
            lo_cnt = tree.lower_bound(l)
            ans = tree.prefix_sum(hi_cnt) - tree.prefix_sum(lo_cnt)
            ans -= (hi_cnt - lo_cnt) * (l - 1)
            out.append(ans)

    if out:
        sys.stdout.write("\n".join(map(str, out)) + "\n")

# Script entry point: read stdin and answer all queries.
solve()