結果

問題 No.3239 Omnibus
ユーザー Nauclhlt🪷
提出日時 2025-08-14 22:36:17
言語 PyPy3
(7.3.15)
結果
TLE  
実行時間 -
コード長 6,477 bytes
コンパイル時間 366 ms
コンパイル使用メモリ 82,888 KB
実行使用メモリ 192,760 KB
最終ジャッジ日時 2025-08-14 22:36:26
合計ジャッジ時間 8,857 ms
ジャッジサーバーID
(参考情報)
judge3 / judge4
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 1
other AC * 4 TLE * 1 -- * 27
権限があれば一括ダウンロードができます

ソースコード

diff #

class Set:
    class Node:
        __slots__ = ['value', 'sum', 'left', 'right', 'bias', 'height', 'size']
        
        def __init__(self, value):
            self.value = value
            self.sum = value
            self.left = None
            self.right = None
            self.bias = 0
            self.height = 1
            self.size = 1
        
        def left_height(self):
            return self.left.height if self.left else 0
        
        def right_height(self):
            return self.right.height if self.right else 0
        
        def left_size(self):
            return self.left.size if self.left else 0
        
        def right_size(self):
            return self.right.size if self.right else 0
    
    def __init__(self):
        self.root = None
    
    def height_of(self, node):
        if not node:
            return 0
        return max(node.left_height(), node.right_height()) + 1
    
    def size_of(self, node):
        if not node:
            return 0
        return node.left_size() + node.right_size() + 1
    
    def sum_of(self, node):
        if not node:
            return 0
        return node.sum
    
    def update(self, node):
        if not node:
            return
        node.height = self.height_of(node)
        node.size = self.size_of(node)
        node.bias = node.left_height() - node.right_height()
        node.sum = self.sum_of(node.left) + self.sum_of(node.right) + node.value
    
    def rotate_left(self, node):
        r = node.right
        node.right = r.left
        r.left = node
        self.update(node)
        self.update(r)
        return r
    
    def rotate_right(self, node):
        l = node.left
        node.left = l.right
        l.right = node
        self.update(node)
        self.update(l)
        return l
    
    def balance(self, node):
        if not node:
            return node
        
        if node.bias >= 2:
            if node.left and node.left.bias < 0:
                node.left = self.rotate_left(node.left)
            return self.rotate_right(node)
        
        if node.bias <= -2:
            if node.right and node.right.bias > 0:
                node.right = self.rotate_right(node.right)
            return self.rotate_left(node)
        
        return node
    
    def add_rec(self, cur, value):
        if not cur:
            return self.Node(value)
        
        if value < cur.value:
            cur.left = self.add_rec(cur.left, value)
        else:
            cur.right = self.add_rec(cur.right, value)
        
        self.update(cur)
        return self.balance(cur)
    
    def get_max_node(self, node):
        while node.right:
            node = node.right
        return node
    
    def remove_rec(self, cur, value):
        if not cur:
            return None
        
        if value == cur.value:
            if cur.left and cur.right:
                mx = self.get_max_node(cur.left)
                cur.value = mx.value
                cur.left = self.remove_rec(cur.left, mx.value)
            else:
                nxt = cur.left if cur.left else cur.right
                return nxt
        elif value < cur.value:
            cur.left = self.remove_rec(cur.left, value)
        else:
            cur.right = self.remove_rec(cur.right, value)
        
        self.update(cur)
        return self.balance(cur)
    
    def lower_bound_rec(self, cur, value, acc):
        if not cur:
            return acc
        
        if value <= cur.value:
            return self.lower_bound_rec(cur.left, value, acc)
        else:
            return self.lower_bound_rec(cur.right, value, acc + cur.left_size() + 1)
    
    def prefix_sum_rec(self, cur, r):
        if not cur:
            return 0
        
        left_sz = cur.left_size()
        if r <= left_sz:
            return self.prefix_sum_rec(cur.left, r)
        elif r == left_sz + 1:
            return self.sum_of(cur.left) + cur.value
        else:
            return self.sum_of(cur.left) + cur.value + self.prefix_sum_rec(cur.right, r - left_sz - 1)
    
    def size(self):
        return self.size_of(self.root)
    
    def add(self, value):
        self.root = self.add_rec(self.root, value)
    
    def remove(self, value):
        self.root = self.remove_rec(self.root, value)
    
    def lower_bound(self, value):
        return self.lower_bound_rec(self.root, value, 0)
    
    def prefix_sum(self, r):
        if not self.root or r <= 0:
            return 0
        if r >= self.size() + 1:
            return self.root.sum
        return self.prefix_sum_rec(self.root, r)


def encode(s, start):
    return (ord(s[start]) - ord('a')) * 26 * 26 + (ord(s[start + 1]) - ord('a')) * 26 + (ord(s[start + 2]) - ord('a'))


def main():
    import sys
    input = sys.stdin.readline
    output = sys.stdout.write
    
    data = input().split()
    N = int(data[0])
    Q = int(data[1])
    S = input().strip()
    
    state = list(S)
    indices = [None] * (26 * 26 * 26)
    
    for i in range(N - 2):
        code = encode(state, i)
        if indices[code] is None:
            indices[code] = Set()
        indices[code].add(i + 1)
    
    for _ in range(Q):
        data = input().split()
        q = int(data[0])
        
        if q == 1:
            k = int(data[1])
            x = data[2].strip()
            k -= 1
            
            for j in range(k - 2, k + 1):
                if j < 0 or j >= N - 2:
                    continue
                c = encode(state, j)
                if indices[c] is not None:
                    indices[c].remove(j + 1)
            
            state[k] = x
            
            for j in range(k - 2, k + 1):
                if j < 0 or j >= N - 2:
                    continue
                c = encode(state, j)
                if indices[c] is None:
                    indices[c] = Set()
                indices[c].add(j + 1)
        
        elif q == 2:
            l = int(data[1])
            r = int(data[2])
            a = data[3].strip()
            code = encode(a, 0)
            
            if indices[code] is None:
                output("0\n")
            else:
                p = indices[code].lower_bound(r - 1)
                s = indices[code].lower_bound(l)
                ans = indices[code].prefix_sum(p) - indices[code].prefix_sum(s)
                ans -= (p - s) * (l - 1)
                output(f"{ans}\n")


if __name__ == "__main__":
    main()
0