結果

問題 No.3239 Omnibus
ユーザー Nauclhlt🪷
提出日時 2025-08-14 22:54:33
言語 PyPy3
(7.3.15)
結果
RE  
実行時間 -
コード長 7,806 bytes
コンパイル時間 329 ms
コンパイル使用メモリ 82,048 KB
実行使用メモリ 130,048 KB
最終ジャッジ日時 2025-08-14 22:54:57
合計ジャッジ時間 14,637 ms
ジャッジサーバーID
(参考情報)
judge2 / judge3
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample RE * 1
other WA * 2 RE * 30
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
# NOTE(review): none of the functions in this file recurse (the tree
# operations are written iteratively), so this very high limit appears to be
# a leftover safety margin — confirm before removing.
sys.setrecursionlimit(1 << 25)

class Set:
    """Ordered multiset of numbers backed by an AVL tree.

    Every node caches its subtree size and the sum of its subtree's values,
    which lets ``lower_bound`` (rank of a value) and ``prefix_sum`` (sum of
    the ``r`` smallest elements) run in O(log n).  Equal values are stored as
    separate nodes; on insertion a duplicate goes to the right subtree, so an
    in-order walk yields the elements in non-decreasing order.
    """

    def __init__(self):
        self.root = None

    class Node:
        __slots__ = ['value', 'sum', 'size', 'height', 'left', 'right']

        def __init__(self, value):
            self.value = value
            self.sum = value  # sum of all values in this subtree
            self.size = 1  # number of nodes in this subtree
            self.height = 1  # AVL height (a leaf has height 1)
            self.left = None
            self.right = None

    # ---------------- helpers ----------------
    def _update(self, node):
        """Recompute the cached height/size/sum of ``node`` from its children."""
        if not node:
            return
        lh = node.left.height if node.left else 0
        rh = node.right.height if node.right else 0
        ls = node.left.size if node.left else 0
        rs = node.right.size if node.right else 0
        ls_sum = node.left.sum if node.left else 0
        rs_sum = node.right.sum if node.right else 0
        node.height = max(lh, rh) + 1
        node.size = ls + rs + 1
        node.sum = ls_sum + rs_sum + node.value

    def _rotate_left(self, node):
        """Rotate ``node`` left and return the new subtree root."""
        r = node.right
        node.right = r.left
        r.left = node
        self._update(node)
        self._update(r)
        return r

    def _rotate_right(self, node):
        """Rotate ``node`` right and return the new subtree root."""
        l = node.left
        node.left = l.right
        l.right = node
        self._update(node)
        self._update(l)
        return l

    def _balance(self, node):
        """Refresh ``node``'s caches, rotate if it is out of AVL balance, and
        return the (possibly new) root of the subtree.  ``None`` is returned
        unchanged."""
        if node is None:
            return None
        self._update(node)
        balance = (node.left.height if node.left else 0) - (node.right.height if node.right else 0)
        if balance >= 2:
            # Left-right case: straighten the left child first.
            if node.left and (node.left.left.height if node.left.left else 0) < (node.left.right.height if node.left.right else 0):
                node.left = self._rotate_left(node.left)
            return self._rotate_right(node)
        if balance <= -2:
            # Right-left case: straighten the right child first.
            if node.right and (node.right.right.height if node.right.right else 0) < (node.right.left.height if node.right.left else 0):
                node.right = self._rotate_right(node.right)
            return self._rotate_left(node)
        return node

    def _rebalance_path(self, path):
        """Rebalance every node on ``path`` bottom-up and relink the results.

        ``path`` is a list of ``(node, went_left)`` pairs ordered from the root
        downward; ``went_left`` says which child of ``node`` the next entry
        (or the modified position) hangs off, so each rebalanced subtree can
        be written back into the correct slot of its parent.
        """
        for i in range(len(path) - 1, -1, -1):
            new_root = self._balance(path[i][0])
            if i == 0:
                self.root = new_root
            else:
                parent, went_left = path[i - 1]
                if went_left:
                    parent.left = new_root
                else:
                    parent.right = new_root

    # ---------------- add (iterative) ----------------
    def add(self, value):
        """Insert one copy of ``value``."""
        if self.root is None:
            self.root = self.Node(value)
            return

        path = []
        node = self.root
        while True:
            went_left = value < node.value
            path.append((node, went_left))
            child = node.left if went_left else node.right
            if child is None:
                if went_left:
                    node.left = self.Node(value)
                else:
                    node.right = self.Node(value)
                break
            node = child

        self._rebalance_path(path)

    # ---------------- remove (iterative) ----------------
    def remove(self, value):
        """Remove one copy of ``value``; do nothing if it is not present."""
        path = []
        node = self.root
        while node is not None and node.value != value:
            went_left = value < node.value
            path.append((node, went_left))
            node = node.left if went_left else node.right

        if node is None:
            return  # not present

        if node.left is not None and node.right is not None:
            # Two children: copy the in-order predecessor (maximum of the left
            # subtree) into this node, then unlink the predecessor instead.
            # The node itself stays on the path so its caches get refreshed.
            path.append((node, True))
            pred = node.left
            while pred.right is not None:
                path.append((pred, False))
                pred = pred.right
            node.value = pred.value
            node = pred

        # ``node`` now has at most one child; splice it out.
        child = node.left if node.left is not None else node.right
        if not path:
            self.root = child
        else:
            parent, went_left = path[-1]
            if went_left:
                parent.left = child
            else:
                parent.right = child

        self._rebalance_path(path)

    # ---------------- lower_bound ----------------
    def lower_bound(self, value):
        """Return the number of stored elements strictly less than ``value``."""
        node = self.root
        res = 0
        while node:
            if value <= node.value:
                node = node.left
            else:
                res += (node.left.size if node.left else 0) + 1
                node = node.right
        return res

    # ---------------- prefix_sum ----------------
    def prefix_sum(self, r):
        """Return the sum of the ``r`` smallest stored elements (0 if ``r <= 0``)."""
        node = self.root
        res = 0
        while node and r > 0:
            left_size = node.left.size if node.left else 0
            left_sum = node.left.sum if node.left else 0
            if r <= left_size:
                node = node.left
            elif r == left_size + 1:
                res += left_sum + node.value
                break
            else:
                res += left_sum + node.value
                r -= left_size + 1
                node = node.right
        return res

# Powers of 26 for base-26 encoding of three-letter codes: [26**0, 26**1, 26**2].
POW26 = [26 ** exponent for exponent in range(3)]

def encode(s, start):
    """Pack the three lowercase letters at ``s[start:start + 3]`` into an int.

    The letters are read as base-26 digits ('a' = 0, ..., 'z' = 25) with the
    first letter most significant, so the result lies in ``[0, 26**3)`` when
    all three are lowercase ASCII letters.

    Args:
        s: a ``str``, or a sequence of ASCII codes (``bytes``, ``bytearray``,
           ``list`` or ``tuple`` of ints).
        start: index of the first of the three characters.

    Raises:
        TypeError: if ``s`` is of an unsupported type.
    """
    if isinstance(s, str):
        a, b, c = ord(s[start]), ord(s[start + 1]), ord(s[start + 2])
    elif isinstance(s, (bytes, bytearray, list, tuple)):
        # Indexing bytes/bytearray already yields ints, just like a list of codes.
        a, b, c = s[start], s[start + 1], s[start + 2]
    else:
        raise TypeError("Unsupported type for encode")
    # Horner's rule; equal to (a-97)*26**2 + (b-97)*26 + (c-97).
    return ((a - 97) * 26 + (b - 97)) * 26 + (c - 97)

def main():
    """Read the input from stdin and write the answers to all type-2 queries.

    Input as parsed below: ``N Q`` on the first line, the string ``S`` on the
    second, then ``Q`` query lines:

    * ``1 k x`` -- replace the 1-indexed ``k``-th character of ``S`` with ``x``;
    * ``2 l r a`` -- for the three-letter string ``a``, print the sum of
      ``i - l + 1`` over every 1-indexed start position ``i`` of an
      occurrence of ``a`` with ``l <= i`` and ``i + 2 <= r``.
    """
    readline = sys.stdin.buffer.readline
    write = sys.stdout.buffer.write

    header = readline().split()
    n = int(header[0])
    q = int(header[1])
    state = bytearray(readline().decode().strip(), 'ascii')

    # One ordered set per three-letter code (created lazily), holding the
    # 1-indexed start positions of that code's occurrences in ``state``.
    indices = [None] * (26 * 26 * 26)

    for i in range(n - 2):
        code = encode(state, i)
        if indices[code] is None:
            indices[code] = Set()
        indices[code].add(i + 1)

    out_lines = []
    for _ in range(q):
        parts = readline().split()
        kind = int(parts[0])

        if kind == 1:
            k = int(parts[1]) - 1
            # Only windows starting at k-2, k-1, k contain position k, and the
            # last valid window start is n-3; clip to both bounds.
            start = max(0, k - 2)
            end = min(n - 3, k)

            for j in range(start, end + 1):
                c = encode(state, j)
                if indices[c] is not None:
                    indices[c].remove(j + 1)

            state[k] = parts[2][0]  # indexing bytes yields the ASCII code

            for j in range(start, end + 1):
                c = encode(state, j)
                if indices[c] is None:
                    indices[c] = Set()
                indices[c].add(j + 1)

        elif kind == 2:
            l = int(parts[1])
            r = int(parts[2])
            code = encode(parts[3].decode(), 0)
            tree = indices[code]

            if tree is None:
                out_lines.append(b"0\n")
                continue

            # s = #positions < l; p = #positions <= r - 2 (the whole window
            # must fit inside [l, r]).  Ranks [s, p) are the counted positions.
            s = tree.lower_bound(l)
            p = tree.lower_bound(r - 1)
            if p < s:
                # r - l < 2: no window fits.  Without clamping, the formula
                # below would report a spurious value built from positions
                # in [r - 1, l - 1].
                p = s
            ans = tree.prefix_sum(p) - tree.prefix_sum(s) - (p - s) * (l - 1)
            out_lines.append(f"{ans}\n".encode())

    write(b"".join(out_lines))

# Run the solver only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
0