結果

問題 No.3239 Omnibus
ユーザー Nauclhlt🪷
提出日時 2025-08-14 22:47:19
言語 PyPy3
(7.3.15)
結果
WA  
実行時間 -
コード長 7,172 bytes
コンパイル時間 391 ms
コンパイル使用メモリ 82,276 KB
実行使用メモリ 150,272 KB
最終ジャッジ日時 2025-08-14 22:48:16
合計ジャッジ時間 49,631 ms
ジャッジサーバーID
(参考情報)
judge5 / judge4
このコードへのチャレンジ
(要ログイン)
ファイルパターン 結果
sample AC * 1
other AC * 4 WA * 26 TLE * 1 -- * 1
権限があれば一括ダウンロードができます

ソースコード

diff #

import sys
# Raise the interpreter's recursion limit to 2**25.
# NOTE(review): nothing visible in this file recurses (all tree walks are
# iterative), so this may be unnecessary -- confirm before removing.
sys.setrecursionlimit(1 << 25)

class Set:
    """Ordered collection of numbers stored in an AVL tree.

    Every node caches the height, element count and value sum of its
    subtree, which lets ``lower_bound`` (rank query) and ``prefix_sum``
    (sum of the ``r`` smallest values) run along a single root-to-leaf
    walk.  ``add`` and ``remove`` rebalance every node on the access path,
    so the tree height stays logarithmic in the number of elements.

    Equal values are inserted into the right subtree; the visible caller
    only stores distinct values.
    """
    __slots__ = ['root']

    class Node:
        """Tree node holding one value plus cached subtree aggregates."""
        __slots__ = ['value', 'sum', 'left', 'right', 'bias', 'height', 'size']

        def __init__(self, value):
            self.value = value
            self.sum = value     # sum of all values in this subtree
            self.left = None
            self.right = None
            self.bias = 0        # height(left) - height(right)
            self.height = 1      # a single node has height 1
            self.size = 1        # number of nodes in this subtree

    def __init__(self):
        """Create an empty set."""
        self.root = None

    def update(self, node):
        """Recompute the cached aggregates of ``node`` from its children.

        The children's own caches must already be up to date.  Passing
        ``None`` is allowed and does nothing.
        """
        if node is None:
            return
        left, right = node.left, node.right
        if left is not None:
            lh, ls, lsum = left.height, left.size, left.sum
        else:
            lh = ls = lsum = 0
        if right is not None:
            rh, rs, rsum = right.height, right.size, right.sum
        else:
            rh = rs = rsum = 0

        node.height = max(lh, rh) + 1
        node.size = ls + rs + 1
        node.bias = lh - rh
        node.sum = lsum + rsum + node.value

    def rotate_left(self, node):
        """Rotate ``node`` left and return the new subtree root.

        ``node.right`` must exist.  The caller is responsible for linking
        the returned node into the parent.
        """
        r = node.right
        node.right = r.left
        r.left = node
        # ``node`` is now the child, so refresh it before its new parent.
        self.update(node)
        self.update(r)
        return r

    def rotate_right(self, node):
        """Rotate ``node`` right and return the new subtree root.

        ``node.left`` must exist.  The caller is responsible for linking
        the returned node into the parent.
        """
        l = node.left
        node.left = l.right
        l.right = node
        # ``node`` is now the child, so refresh it before its new parent.
        self.update(node)
        self.update(l)
        return l

    def balance(self, node):
        """Restore the AVL invariant at ``node`` and return the subtree root.

        Uses the cached ``bias`` of ``node`` and of its children, so those
        must be current.  Returns ``node`` unchanged when it is already
        balanced (or is ``None``).
        """
        if node is None:
            return node

        if node.bias >= 2:
            # Left-heavy; a right-leaning left child needs a double rotation.
            if node.left is not None and node.left.bias < 0:
                node.left = self.rotate_left(node.left)
            return self.rotate_right(node)

        if node.bias <= -2:
            # Right-heavy; a left-leaning right child needs a double rotation.
            if node.right is not None and node.right.bias > 0:
                node.right = self.rotate_right(node.right)
            return self.rotate_left(node)

        return node

    def _rebalance_path(self, path):
        """Refresh and rebalance every node of ``path`` from deepest to root.

        ``path`` lists the ancestors of a structural change, root first.
        Each rebalanced subtree is linked back into its parent (or into
        ``self.root`` for the topmost node) so no part of the tree is lost.
        """
        for depth in range(len(path) - 1, -1, -1):
            node = path[depth]
            self.update(node)
            subtree = self.balance(node)
            if depth == 0:
                self.root = subtree
            else:
                parent = path[depth - 1]
                # The parent still points at ``node``; identity tells us
                # which side to relink.
                if parent.left is node:
                    parent.left = subtree
                else:
                    parent.right = subtree

    def add(self, value):
        """Insert ``value`` into the set, keeping the tree balanced."""
        fresh = self.Node(value)
        if self.root is None:
            self.root = fresh
            return

        path = []
        cur = self.root
        while cur is not None:
            path.append(cur)
            cur = cur.left if value < cur.value else cur.right

        parent = path[-1]
        if value < parent.value:
            parent.left = fresh
        else:
            parent.right = fresh

        self._rebalance_path(path)

    def get_max_node(self, node):
        """Return the right-most (largest) node of the subtree at ``node``."""
        while node.right:
            node = node.right
        return node

    def remove(self, value):
        """Remove one occurrence of ``value``; do nothing if it is absent."""
        path = []  # strict ancestors of the node that gets unlinked
        cur = self.root
        while cur is not None and cur.value != value:
            path.append(cur)
            cur = cur.left if value < cur.value else cur.right
        if cur is None:
            return

        target = cur
        if target.left is not None and target.right is not None:
            # Two children: overwrite with the in-order predecessor's value
            # and unlink the predecessor (which has no right child) instead.
            path.append(target)
            pred = target.left
            while pred.right is not None:
                path.append(pred)
                pred = pred.right
            target.value = pred.value
            target = pred

        child = target.left if target.left is not None else target.right
        if path:
            parent = path[-1]
            if parent.left is target:
                parent.left = child
            else:
                parent.right = child
        else:
            self.root = child

        self._rebalance_path(path)

    def lower_bound(self, value):
        """Return how many stored values are strictly less than ``value``."""
        count = 0
        cur = self.root
        while cur is not None:
            if value <= cur.value:
                cur = cur.left
            else:
                # ``cur`` and its whole left subtree are smaller than value.
                count += (cur.left.size if cur.left is not None else 0) + 1
                cur = cur.right
        return count

    def prefix_sum(self, r):
        """Return the sum of the ``r`` smallest stored values.

        ``r <= 0`` (or an empty set) yields 0; ``r`` at or above the number
        of elements yields the sum of all values.
        """
        root = self.root
        if root is None or r <= 0:
            return 0
        if r >= root.size:
            return root.sum

        total = 0
        cur = root
        while cur is not None:
            left = cur.left
            left_size = left.size if left is not None else 0
            if r <= left_size:
                cur = left
                continue
            # Take the whole left subtree plus ``cur`` itself.
            total += (left.sum if left is not None else 0) + cur.value
            r -= left_size + 1
            if r == 0:
                break
            cur = cur.right
        return total

# Place values of a 3-digit base-26 number, least significant first.
POW26 = [1, 26, 26*26]


def encode(s, start):
    """Pack the three lowercase letters at ``s[start:start+3]`` into one int.

    Each letter is mapped to 0..25 by subtracting 97 (the code of ``'a'``)
    and the three digits are combined as a base-26 number, most significant
    digit first.

    ``s`` may be a ``str`` (characters are converted with ``ord``) or a
    ``list`` / ``bytearray`` whose items are already character codes.
    Any other type raises ``TypeError``.
    """
    if isinstance(s, str):
        high = (ord(s[start]) - 97) * POW26[2]
        middle = (ord(s[start + 1]) - 97) * POW26[1]
        low = ord(s[start + 2]) - 97
        return high + middle + low

    if isinstance(s, (list, bytearray)):
        high = (s[start] - 97) * POW26[2]
        middle = (s[start + 1] - 97) * POW26[1]
        low = s[start + 2] - 97
        return high + middle + low

    raise TypeError("Unsupported type for encode")

def main():
    """Read the string and the queries from stdin and print the answers.

    Input as parsed here: a line ``N Q``, a line with the string, then
    ``Q`` query lines of the form ``1 k x`` or ``2 l r a``.

    * ``1 k x`` replaces the k-th character (1-based) with ``x``.
    * ``2 l r a`` prints, over the 1-based start positions ``p`` of the
      3-letter pattern ``a`` with ``l <= p <= r - 2``, the sum of
      ``p - l + 1``.

    All answers are buffered and written in one call at the end.
    """
    read_line = sys.stdin.buffer.readline
    emit = sys.stdout.buffer.write

    header = read_line().split()
    N = int(header[0])
    Q = int(header[1])
    text = read_line().decode().strip()

    state = bytearray(text, 'ascii')
    # One Set of 1-based start positions per 3-letter code, created lazily.
    indices = [None] * (26 * 26 * 26)

    for pos in range(N - 2):
        code = encode(state, pos)
        bucket = indices[code]
        if bucket is None:
            bucket = indices[code] = Set()
        bucket.add(pos + 1)

    answers = []
    for _ in range(Q):
        tokens = read_line().split()
        kind = int(tokens[0])

        if kind == 1:
            k = int(tokens[1]) - 1
            new_char = tokens[2].decode()

            # Check the boundaries strictly: only windows that contain
            # index k and fit inside the string are affected.  Window
            # starts run up to N-3 (0-based), so that is the maximum.
            first = max(0, k - 2)
            last = min(N - 3, k)
            windows = range(first, last + 1)

            # Unregister the affected windows under their old codes.
            for j in windows:
                bucket = indices[encode(state, j)]
                if bucket is not None:
                    bucket.remove(j + 1)

            state[k] = ord(new_char[0])

            # Register the same windows again under their new codes.
            for j in windows:
                code = encode(state, j)
                bucket = indices[code]
                if bucket is None:
                    bucket = indices[code] = Set()
                bucket.add(j + 1)
            continue

        if kind != 2:
            continue

        l = int(tokens[1])
        r = int(tokens[2])
        pattern = tokens[3].decode()  # the pattern is the fourth token
        bucket = indices[encode(pattern, 0)]

        if bucket is None:
            answers.append(b"0\n")
            continue

        upper = bucket.lower_bound(r - 1)  # positions <= r - 2
        lower = bucket.lower_bound(l)      # positions <  l
        ans = bucket.prefix_sum(upper) - bucket.prefix_sum(lower)
        # Shift every counted position so that l maps to 1.
        ans -= (upper - lower) * (l - 1)
        answers.append(f"{ans}\n".encode())

    emit(b"".join(answers))

# Run the solver only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
0