結果
問題 | No.3239 Omnibus |
ユーザー |
![]() |
提出日時 | 2025-08-14 22:51:23 |
言語 | Python3 (3.13.1 + numpy 2.2.1 + scipy 1.14.1) |
結果 | AC |
実行時間 | 9,463 ms / 10,000 ms |
コード長 | 6,297 bytes |
コンパイル時間 | 252 ms |
コンパイル使用メモリ | 12,800 KB |
実行使用メモリ | 66,252 KB |
最終ジャッジ日時 | 2025-08-14 23:54:05 |
合計ジャッジ時間 | 125,122 ms |
ジャッジサーバーID (参考情報) | judge5 / judge1 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | AC * 1 |
other | AC * 33 |
ソースコード
import sys

# The AVL routines below recurse only O(log n) deep; this raised limit is
# merely a generous safety margin.
sys.setrecursionlimit(1 << 25)


class Set:
    """Sorted collection of numbers backed by an AVL tree.

    Every node caches its subtree height, size and sum, so insertion,
    removal, rank queries (``lower_bound``) and sums of the smallest
    elements (``prefix_sum``) each take O(log n).  ``add`` does not
    deduplicate values.
    """

    class Node:
        __slots__ = ['value', 'sum', 'left', 'right', 'height', 'size']

        def __init__(self, value):
            self.value = value
            self.sum = value  # sum of every value in this subtree
            self.left = None
            self.right = None
            self.height = 1
            self.size = 1  # number of nodes in this subtree

    def __init__(self):
        self.root = None

    # ---------------- AVL primitives ----------------
    def _height(self, node):
        return node.height if node else 0

    def _size(self, node):
        return node.size if node else 0

    def _sum(self, node):
        return node.sum if node else 0

    def _update(self, node):
        """Recompute the cached height/size/sum of *node* from its children."""
        if not node:
            return
        node.height = max(self._height(node.left), self._height(node.right)) + 1
        node.size = self._size(node.left) + self._size(node.right) + 1
        node.sum = self._sum(node.left) + self._sum(node.right) + node.value

    def _rotate_left(self, node):
        r = node.right
        node.right = r.left
        r.left = node
        self._update(node)  # node is now the child, so refresh it first
        self._update(r)
        return r

    def _rotate_right(self, node):
        l = node.left
        node.left = l.right
        l.right = node
        self._update(node)  # node is now the child, so refresh it first
        self._update(l)
        return l

    def _balance(self, node):
        """Refresh *node*'s cache and rotate when its children's heights differ by 2.

        Returns the (possibly new) root of this subtree.
        """
        if not node:
            return node
        self._update(node)
        balance = self._height(node.left) - self._height(node.right)
        if balance >= 2:
            # Left-right shape: rotate the left child first.
            if self._height(node.left.left) < self._height(node.left.right):
                node.left = self._rotate_left(node.left)
            return self._rotate_right(node)
        if balance <= -2:
            # Right-left shape: rotate the right child first.
            if self._height(node.right.right) < self._height(node.right.left):
                node.right = self._rotate_right(node.right)
            return self._rotate_left(node)
        return node

    # ---------------- insertion ----------------
    def _add(self, node, value):
        if not node:
            return self.Node(value)
        if value < node.value:
            node.left = self._add(node.left, value)
        else:
            node.right = self._add(node.right, value)
        return self._balance(node)

    def add(self, value):
        """Insert *value* (an equal value already present is kept as well)."""
        self.root = self._add(self.root, value)

    # ---------------- removal ----------------
    def _get_max(self, node):
        while node.right:
            node = node.right
        return node

    def _remove(self, node, value):
        if not node:
            return None
        if value < node.value:
            node.left = self._remove(node.left, value)
        elif value > node.value:
            node.right = self._remove(node.right, value)
        else:
            if not node.left:
                return node.right
            if not node.right:
                return node.left
            # Two children: take over the in-order predecessor's value,
            # then delete that predecessor from the left subtree.
            max_left = self._get_max(node.left)
            node.value = max_left.value
            node.left = self._remove(node.left, max_left.value)
        return self._balance(node)

    def remove(self, value):
        """Remove one node holding *value*; do nothing if it is absent."""
        self.root = self._remove(self.root, value)

    # ---------------- lower_bound ----------------
    def lower_bound(self, value):
        """Return how many stored elements are strictly less than *value*."""
        node = self.root
        res = 0
        while node:
            if value <= node.value:
                node = node.left
            else:
                res += self._size(node.left) + 1
                node = node.right
        return res

    # ---------------- prefix_sum ----------------
    def prefix_sum(self, r):
        """Return the sum of the *r* smallest stored elements (all of them if r >= size)."""
        node = self.root
        res = 0
        while node and r > 0:
            left_sz = self._size(node.left)
            if r <= left_sz:
                node = node.left
            elif r == left_sz + 1:
                res += self._sum(node.left) + node.value
                break
            else:
                res += self._sum(node.left) + node.value
                r -= left_sz + 1
                node = node.right
        return res


# Place values for packing three lowercase letters into one base-26 code.
POW26 = [1, 26, 26 * 26]


def encode(s, start):
    """Pack the three lowercase letters s[start:start+3] into an int in [0, 26**3).

    *s* may be a ``str`` or a sequence of byte values (``bytes``,
    ``bytearray`` or ``list`` of ints).

    Raises:
        TypeError: if *s* is of any other type.
    """
    if isinstance(s, str):
        a, b, c = ord(s[start]), ord(s[start + 1]), ord(s[start + 2])
    elif isinstance(s, (list, bytes, bytearray)):
        a, b, c = s[start], s[start + 1], s[start + 2]
    else:
        raise TypeError("Unsupported type for encode")
    return (a - 97) * POW26[2] + (b - 97) * POW26[1] + (c - 97)


def solve(data):
    """Process a complete problem input and return the complete output.

    *data* holds whitespace-separated tokens: ``N Q``, the string ``S``, then
    Q queries, each either

    * ``1 k x`` -- set the k-th character (1-indexed) of S to letter x, or
    * ``2 l r a`` -- for the 3-letter pattern a, output the sum of
      (i - l + 1) over every 1-indexed start position i with
      l <= i <= r - 2 and S[i:i+3] == a.

    Returns the answers to the type-2 queries, one per line, as bytes.

    Raises:
        ValueError: on a query type other than 1 or 2.
    """
    tokens = data.split()
    n = int(tokens[0])
    q = int(tokens[1])
    state = bytearray(tokens[2])

    # buckets[code] holds the 1-indexed start positions of every window whose
    # three letters encode to `code`; buckets are created lazily.
    buckets = [None] * (26 * 26 * 26)
    for i in range(n - 2):
        code = encode(state, i)
        if buckets[code] is None:
            buckets[code] = Set()
        buckets[code].add(i + 1)

    out = []
    pos = 3
    for _ in range(q):
        kind = int(tokens[pos])
        if kind == 1:
            k = int(tokens[pos + 1]) - 1
            ch = tokens[pos + 2][0]
            pos += 3
            if state[k] == ch:
                continue  # nothing changes, so no window needs moving
            # Only windows starting at k-2 .. k contain position k; clamp to
            # valid starts 0 .. n-3.
            start = max(0, k - 2)
            end = min(n - 3, k)
            for j in range(start, end + 1):
                bucket = buckets[encode(state, j)]
                if bucket is not None:
                    bucket.remove(j + 1)
            state[k] = ch
            for j in range(start, end + 1):
                code = encode(state, j)
                bucket = buckets[code]
                if bucket is None:
                    bucket = buckets[code] = Set()
                bucket.add(j + 1)
        elif kind == 2:
            l = int(tokens[pos + 1])
            r = int(tokens[pos + 2])
            pattern = tokens[pos + 3]
            pos += 4
            bucket = buckets[encode(pattern, 0)]
            if bucket is None or r - l < 2:
                # No occurrences, or no full window fits inside [l, r].
                out.append(b"0\n")
                continue
            hi = bucket.lower_bound(r - 1)  # count of starts <= r - 2
            lo = bucket.lower_bound(l)  # count of starts < l
            # Sum of (i - l + 1) = sum(i) - count * (l - 1).
            ans = bucket.prefix_sum(hi) - bucket.prefix_sum(lo) - (hi - lo) * (l - 1)
            out.append(f"{ans}\n".encode())
        else:
            raise ValueError(f"unknown query type: {kind}")
    return b"".join(out)


def main():
    """Read the whole input from stdin and write every answer to stdout."""
    sys.stdout.buffer.write(solve(sys.stdin.buffer.read()))


if __name__ == "__main__":
    main()