結果

問題 No.3239 Omnibus
ユーザー Nauclhlt🪷
提出日時 2025-08-14 22:51:23
言語 Python3 (3.13.1 + numpy 2.2.1 + scipy 1.14.1)
結果
AC  
実行時間 9,463 ms / 10,000 ms
コード長 6,297 bytes
コンパイル時間 252 ms
コンパイル使用メモリ 12,800 KB
実行使用メモリ 66,252 KB
最終ジャッジ日時 2025-08-14 23:54:05
合計ジャッジ時間 125,122 ms
ジャッジサーバーID (参考情報): judge5 / judge1
このコードへのチャレンジ (要ログイン)
ファイルパターン 結果
sample AC * 1
other AC * 33
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
# Raise the interpreter's recursion cap (2**25 == 1 << 25) because the AVL
# insert/remove helpers below are recursive.
sys.setrecursionlimit(2 ** 25)

class Set:
    """Sorted multiset backed by an AVL tree.

    Every node caches the height, element count and element sum of its
    subtree, so ``lower_bound`` and ``prefix_sum`` only walk one root-to-leaf
    path. Equal values may be stored more than once; ``add`` sends ties to
    the right subtree.
    """

    class Node:
        """AVL node holding one value plus its subtree aggregates."""

        __slots__ = ['value', 'sum', 'left', 'right', 'height', 'size']

        def __init__(self, value):
            self.value = value
            self.sum = value
            self.left = self.right = None
            self.height = self.size = 1

    def __init__(self):
        self.root = None  # None is the empty tree

    # ---------- None-safe aggregate accessors ----------
    def _height(self, node):
        return 0 if node is None else node.height

    def _size(self, node):
        return 0 if node is None else node.size

    def _sum(self, node):
        return 0 if node is None else node.sum

    def _update(self, node):
        """Recompute ``node``'s cached height/size/sum from its children."""
        if node is None:
            return
        lo, hi = node.left, node.right
        node.height = 1 + max(self._height(lo), self._height(hi))
        node.size = 1 + self._size(lo) + self._size(hi)
        node.sum = self._sum(lo) + self._sum(hi) + node.value

    def _rotate_left(self, node):
        """Lift ``node.right`` above ``node``; return the new subtree root."""
        pivot = node.right
        node.right, pivot.left = pivot.left, node
        self._update(node)   # the demoted node first: pivot depends on it
        self._update(pivot)
        return pivot

    def _rotate_right(self, node):
        """Lift ``node.left`` above ``node``; return the new subtree root."""
        pivot = node.left
        node.left, pivot.right = pivot.right, node
        self._update(node)
        self._update(pivot)
        return pivot

    def _balance(self, node):
        """Refresh ``node`` and fix an AVL violation at it; return the subtree root."""
        if node is None:
            return None
        self._update(node)
        skew = self._height(node.left) - self._height(node.right)
        if skew > 1:
            child = node.left
            # Left-right case: straighten the child before the main rotation.
            if self._height(child.right) > self._height(child.left):
                node.left = self._rotate_left(child)
            return self._rotate_right(node)
        if skew < -1:
            child = node.right
            # Right-left case: straighten the child before the main rotation.
            if self._height(child.left) > self._height(child.right):
                node.right = self._rotate_right(child)
            return self._rotate_left(node)
        return node

    # ---------- insertion ----------
    def _add(self, node, value):
        """Insert ``value`` below ``node``; return the rebalanced subtree root."""
        trail = []  # (ancestor, went_left) pairs, top-down
        cur = node
        while cur is not None:
            went_left = value < cur.value
            trail.append((cur, went_left))
            cur = cur.left if went_left else cur.right
        subtree = self.Node(value)
        # Re-link and rebalance bottom-up, exactly as the recursive unwind would.
        for parent, went_left in reversed(trail):
            if went_left:
                parent.left = subtree
            else:
                parent.right = subtree
            subtree = self._balance(parent)
        return subtree

    def add(self, value):
        """Insert ``value`` (duplicates allowed)."""
        self.root = self._add(self.root, value)

    # ---------- removal ----------
    def _get_max(self, node):
        """Return the rightmost (largest) node of a non-empty subtree."""
        while node.right is not None:
            node = node.right
        return node

    def _remove(self, node, value):
        """Delete one occurrence of ``value`` below ``node``; return the new root."""
        if node is None:
            return None  # value not present: nothing to do
        if value < node.value:
            node.left = self._remove(node.left, value)
        elif value > node.value:
            node.right = self._remove(node.right, value)
        elif node.left is None:
            return node.right
        elif node.right is None:
            return node.left
        else:
            # Two children: take over the in-order predecessor's value, then
            # delete that value from the left subtree.
            pred = self._get_max(node.left).value
            node.value = pred
            node.left = self._remove(node.left, pred)
        return self._balance(node)

    def remove(self, value):
        """Remove one occurrence of ``value``; a missing value is a no-op."""
        self.root = self._remove(self.root, value)

    # ---------- order statistics ----------
    def lower_bound(self, value):
        """Return how many stored elements are strictly less than ``value``."""
        count = 0
        cur = self.root
        while cur is not None:
            if value <= cur.value:
                cur = cur.left
                continue
            count += self._size(cur.left) + 1
            cur = cur.right
        return count

    def prefix_sum(self, r):
        """Return the sum of the ``r`` smallest elements (every element if r >= size)."""
        total = 0
        cur = self.root
        while cur is not None and r > 0:
            skip = self._size(cur.left)
            if r <= skip:
                cur = cur.left
            else:
                # Take the whole left subtree plus this node; once r hits 0
                # the loop condition ends the walk.
                total += self._sum(cur.left) + cur.value
                r -= skip + 1
                cur = cur.right
        return total

# Precompute powers for encoding
POW26 = [1, 26, 26*26]

def encode(s, start):
    if isinstance(s, str):
        return (ord(s[start]) - 97) * POW26[2] + (ord(s[start+1]) - 97) * POW26[1] + (ord(s[start+2]) - 97)
    elif isinstance(s, (list, bytearray)):
        return (s[start] - 97) * POW26[2] + (s[start+1] - 97) * POW26[1] + (s[start+2] - 97)
    else:
        raise TypeError("Unsupported type for encode")

def main():
    """Answer trigram-position queries read from stdin and write answers to stdout.

    Input, one item per line:
      * ``N Q``: length of the string and number of queries.
      * ``S``: the initial string.
      * ``1 k x``: replace the k-th character (1-based) of S with ``x``.
      * ``2 l r a``: over every 1-based start position p of the three-letter
        string ``a`` with ``l <= p <= r - 2``, print the sum of ``p - l + 1``.
    Lines whose query type is neither 1 nor 2 are ignored.
    """
    readline = sys.stdin.buffer.readline  # local name; does not shadow input()

    header = readline().split()
    n = int(header[0])
    q = int(header[1])
    state = bytearray(readline().strip())

    # indices[code] holds the 1-based start positions of every occurrence of
    # the trigram whose encode() value is `code`. Buckets are created lazily.
    indices = [None] * (26 * 26 * 26)

    for i in range(n - 2):
        code = encode(state, i)
        bucket = indices[code]
        if bucket is None:
            bucket = indices[code] = Set()
        bucket.add(i + 1)

    out_lines = []
    for _ in range(q):
        parts = readline().split()
        kind = int(parts[0])

        if kind == 1:
            k = int(parts[1]) - 1  # 0-based index of the changed character
            # Only trigrams starting in [k - 2, k] contain index k. Clamp that
            # window to the valid start range [0, n - 3].
            lo = max(0, k - 2)
            hi = min(n - 3, k)

            for j in range(lo, hi + 1):
                bucket = indices[encode(state, j)]
                if bucket is not None:
                    bucket.remove(j + 1)

            state[k] = parts[2][0]  # indexing bytes yields the character code

            for j in range(lo, hi + 1):
                code = encode(state, j)
                bucket = indices[code]
                if bucket is None:
                    bucket = indices[code] = Set()
                bucket.add(j + 1)

        elif kind == 2:
            l = int(parts[1])
            r = int(parts[2])
            bucket = indices[encode(parts[3].decode(), 0)]
            if bucket is None:
                out_lines.append(b"0\n")
                continue
            # Start positions p with l <= p <= r - 2 occupy ranks
            # [lo_rank, hi_rank) in the sorted bucket.
            hi_rank = bucket.lower_bound(r - 1)
            lo_rank = bucket.lower_bound(l)
            total = bucket.prefix_sum(hi_rank) - bucket.prefix_sum(lo_rank)
            total -= (hi_rank - lo_rank) * (l - 1)  # shift each p to p - l + 1
            out_lines.append(b"%d\n" % total)

    sys.stdout.buffer.write(b"".join(out_lines))

# Run the solver only when executed as a script, not on import.
if __name__ == "__main__":
    main()