結果
| 問題 | 
                            No.3239 Omnibus
                             | 
                    
| コンテスト | |
| ユーザー | 
                             Nauclhlt🪷
                         | 
                    
| 提出日時 | 2025-08-14 22:41:28 | 
| 言語 | PyPy3  (7.3.15)  | 
                    
| 結果 | 
                             
                                RE
                                 
                             
                            
                         | 
                    
| 実行時間 | - | 
| コード長 | 6,785 bytes | 
| コンパイル時間 | 342 ms | 
| コンパイル使用メモリ | 82,060 KB | 
| 実行使用メモリ | 250,508 KB | 
| 最終ジャッジ日時 | 2025-08-14 22:42:15 | 
| 合計ジャッジ時間 | 42,073 ms | 
| 
                            ジャッジサーバーID (参考情報)  | 
                        judge3 / judge5 | 
(要ログイン)
| ファイルパターン | 結果 | 
|---|---|
| sample | AC * 1 | 
| other | AC * 2 WA * 11 RE * 17 TLE * 1 -- * 1 | 
ソースコード
import sys
sys.setrecursionlimit(1 << 25)
class Set:
    __slots__ = ['root']
    
    class Node:
        __slots__ = ['value', 'sum', 'left', 'right', 'bias', 'height', 'size']
        
        def __init__(self, value):
            self.value = value
            self.sum = value
            self.left = None
            self.right = None
            self.bias = 0
            self.height = 1
            self.size = 1
    
    def __init__(self):
        self.root = None
    
    def update(self, node):
        if not node:
            return
        lh = node.left.height if node.left else 0
        rh = node.right.height if node.right else 0
        ls = node.left.size if node.left else 0
        rs = node.right.size if node.right else 0
        ls_sum = node.left.sum if node.left else 0
        rs_sum = node.right.sum if node.right else 0
        
        node.height = max(lh, rh) + 1
        node.size = ls + rs + 1
        node.bias = lh - rh
        node.sum = ls_sum + rs_sum + node.value
    
    def rotate_left(self, node):
        r = node.right
        node.right = r.left
        r.left = node
        self.update(node)
        self.update(r)
        return r
    
    def rotate_right(self, node):
        l = node.left
        node.left = l.right
        l.right = node
        self.update(node)
        self.update(l)
        return l
    
    def balance(self, node):
        if not node:
            return node
        
        if node.bias >= 2:
            if node.left and node.left.bias < 0:
                node.left = self.rotate_left(node.left)
            return self.rotate_right(node)
        
        if node.bias <= -2:
            if node.right and node.right.bias > 0:
                node.right = self.rotate_right(node.right)
            return self.rotate_left(node)
        
        return node
    
    def add(self, value):
        if not self.root:
            self.root = self.Node(value)
            return
        
        path = []
        cur = self.root
        while True:
            path.append(cur)
            if value < cur.value:
                if not cur.left:
                    cur.left = self.Node(value)
                    break
                cur = cur.left
            else:
                if not cur.right:
                    cur.right = self.Node(value)
                    break
                cur = cur.right
        
        for node in reversed(path):
            self.update(node)
        
        self.root = self.balance(self.root)
    
    def get_max_node(self, node):
        while node.right:
            node = node.right
        return node
    
    def remove(self, value):
        if not self.root:
            return
        
        path = []
        cur = self.root
        while cur:
            path.append(cur)
            if value == cur.value:
                break
            elif value < cur.value:
                cur = cur.left
            else:
                cur = cur.right
        else:
            return
        
        node = path[-1]
        if node.left and node.right:
            mx = self.get_max_node(node.left)
            node.value = mx.value
            value = mx.value
            cur = node.left
            path.append(cur)
            while cur.right:
                path.append(cur.right)
                cur = cur.right
            node = cur
        
        parent = path[-2] if len(path) > 1 else None
        child = node.left if node.left else node.right
        
        if parent:
            if parent.left == node:
                parent.left = child
            else:
                parent.right = child
        else:
            self.root = child
        
        for n in reversed(path[:-1]):
            self.update(n)
            if len(path) > 1 and n == path[-2]:
                self.root = self.balance(n)
    
    def lower_bound(self, value):
        if not self.root:
            return 0
        
        res = 0
        cur = self.root
        while cur:
            if value <= cur.value:
                cur = cur.left
            else:
                res += (cur.left.size if cur.left else 0) + 1
                cur = cur.right
        return res
    
    def prefix_sum(self, r):
        if not self.root or r <= 0:
            return 0
        if r >= self.root.size + 1:
            return self.root.sum
        
        res = 0
        cur = self.root
        while cur:
            left_sz = cur.left.size if cur.left else 0
            if r <= left_sz:
                cur = cur.left
            elif r == left_sz + 1:
                res += (cur.left.sum if cur.left else 0) + cur.value
                break
            else:
                res += (cur.left.sum if cur.left else 0) + cur.value
                r -= left_sz + 1
                cur = cur.right
        return res
# Precompute powers for encoding
POW26 = [1, 26, 26*26]
def encode(s, start):
    return ((s[start]) - 97) * POW26[2] + ((s[start+1]) - 97) * POW26[1] + ((s[start+2]) - 97)
def main():
    input = sys.stdin.buffer.readline
    output = sys.stdout.buffer.write
    
    data = input().split()
    N = int(data[0])
    Q = int(data[1])
    S = input().decode().strip()
    
    state = bytearray(S, 'ascii')
    indices = [None] * (26 * 26 * 26)
    
    for i in range(N - 2):
        code = encode(state, i)
        if indices[code] is None:
            indices[code] = Set()
        indices[code].add(i + 1)
    
    out_lines = []
    for _ in range(Q):
        data = input().split()
        q = int(data[0])
        
        if q == 1:
            k = int(data[1]) - 1
            x = data[2].decode()
            
            start = max(0, k - 2)
            end = min(N - 2, k)
            
            for j in range(start, end + 1):
                c = encode(state, j)
                if indices[c] is not None:
                    indices[c].remove(j + 1)
            
            state[k] = ord(x[0])
            
            for j in range(start, end + 1):
                c = encode(state, j)
                if indices[c] is None:
                    indices[c] = Set()
                indices[c].add(j + 1)
        
        elif q == 2:
            l = int(data[1])
            r = int(data[2])
            a = data[3]
            code = encode(a, 0)
            
            if indices[code] is None:
                out_lines.append(b"0\n")
            else:
                p = indices[code].lower_bound(r - 1)
                s = indices[code].lower_bound(l)
                ans = indices[code].prefix_sum(p) - indices[code].prefix_sum(s)
                ans -= (p - s) * (l - 1)
                out_lines.append(f"{ans}\n".encode())
    
    output(b"".join(out_lines))
if __name__ == "__main__":
    main()
            
            
            
        
            
Nauclhlt🪷