結果
問題 | No.3239 Omnibus |
ユーザー |
![]() |
提出日時 | 2025-08-14 22:54:33 |
言語 | PyPy3 (7.3.15) |
結果 | RE |
実行時間 | - |
コード長 | 7,806 bytes |
コンパイル時間 | 329 ms |
コンパイル使用メモリ | 82,048 KB |
実行使用メモリ | 130,048 KB |
最終ジャッジ日時 | 2025-08-14 22:54:57 |
合計ジャッジ時間 | 14,637 ms |
ジャッジサーバーID (参考情報) | judge2 / judge3 |
(要ログイン)
ファイルパターン | 結果 |
---|---|
sample | RE * 1 |
other | WA * 2 RE * 30 |
ソースコード
import sys

sys.setrecursionlimit(1 << 25)


class Set:
    """Ordered multiset of integers backed by an AVL tree.

    Every node caches the size and the sum of its subtree, so ``add``,
    ``remove``, ``lower_bound`` and ``prefix_sum`` each take O(log n).
    A value equal to a node's value is inserted into that node's right
    subtree, so duplicates are kept.
    """

    class Node:
        __slots__ = ['value', 'sum', 'size', 'height', 'left', 'right']

        def __init__(self, value):
            self.value = value
            self.sum = value   # sum of every value in this subtree
            self.size = 1      # number of nodes in this subtree
            self.height = 1    # AVL height; a leaf has height 1
            self.left = None
            self.right = None

    def __init__(self):
        self.root = None

    # ---------------- helpers ----------------
    @staticmethod
    def _height(node):
        """Return the cached height of ``node`` (0 for an empty subtree)."""
        return node.height if node else 0

    def _update(self, node):
        """Recompute node's cached height, size and sum from its children."""
        if not node:
            return
        left, right = node.left, node.right
        node.height = max(self._height(left), self._height(right)) + 1
        node.size = (left.size if left else 0) + (right.size if right else 0) + 1
        node.sum = (left.sum if left else 0) + (right.sum if right else 0) + node.value

    def _rotate_left(self, node):
        """Rotate ``node`` left; return the new subtree root (its old right child)."""
        pivot = node.right
        node.right = pivot.left
        pivot.left = node
        self._update(node)
        self._update(pivot)
        return pivot

    def _rotate_right(self, node):
        """Rotate ``node`` right; return the new subtree root (its old left child)."""
        pivot = node.left
        node.left = pivot.right
        pivot.right = node
        self._update(node)
        self._update(pivot)
        return pivot

    def _balance(self, node):
        """Refresh ``node``'s caches, restore the AVL invariant there, and
        return the (possibly new) subtree root.

        Callers MUST store the returned node back into the parent link,
        because a rotation changes which node is the subtree root.
        ``None`` (an empty subtree) is returned unchanged.
        """
        if node is None:
            return None
        self._update(node)
        balance = self._height(node.left) - self._height(node.right)
        if balance >= 2:
            left = node.left
            if self._height(left.left) < self._height(left.right):
                # Left-right case: straighten the left child first.
                node.left = self._rotate_left(left)
            return self._rotate_right(node)
        if balance <= -2:
            right = node.right
            if self._height(right.right) < self._height(right.left):
                # Right-left case: straighten the right child first.
                node.right = self._rotate_right(right)
            return self._rotate_left(node)
        return node

    # ---------------- add ----------------
    def add(self, value):
        """Insert ``value``; duplicates are allowed."""
        self.root = self._insert(self.root, value)

    def _insert(self, node, value):
        """Insert into the subtree rooted at ``node``; return its new root."""
        if node is None:
            return self.Node(value)
        if value < node.value:
            node.left = self._insert(node.left, value)
        else:
            node.right = self._insert(node.right, value)
        return self._balance(node)

    # ---------------- remove ----------------
    def remove(self, value):
        """Remove one occurrence of ``value``; do nothing if it is absent."""
        self.root = self._delete(self.root, value)

    def _delete(self, node, value):
        """Delete one ``value`` from the subtree at ``node``; return its new root."""
        if node is None:
            return None  # value is not present
        if value < node.value:
            node.left = self._delete(node.left, value)
        elif node.value < value:
            node.right = self._delete(node.right, value)
        elif node.left is None:
            return node.right  # already balanced, caches already correct
        elif node.right is None:
            return node.left
        else:
            # Two children: take over the in-order predecessor (the maximum
            # of the left subtree), then delete that value from the left.
            pred = node.left
            while pred.right:
                pred = pred.right
            node.value = pred.value
            node.left = self._delete(node.left, pred.value)
        return self._balance(node)

    # ---------------- lower_bound ----------------
    def lower_bound(self, value):
        """Return how many stored elements are strictly less than ``value``."""
        node = self.root
        res = 0
        while node:
            if value <= node.value:
                node = node.left
            else:
                res += (node.left.size if node.left else 0) + 1
                node = node.right
        return res

    # ---------------- prefix_sum ----------------
    def prefix_sum(self, r):
        """Return the sum of the ``r`` smallest elements.

        Returns 0 when ``r <= 0`` and the total sum when ``r`` exceeds the size.
        """
        node = self.root
        res = 0
        while node and r > 0:
            left_size = node.left.size if node.left else 0
            left_sum = node.left.sum if node.left else 0
            if r <= left_size:
                node = node.left
            elif r == left_size + 1:
                res += left_sum + node.value
                break
            else:
                res += left_sum + node.value
                r -= left_size + 1
                node = node.right
        return res


# Place values for base-26 encoding of a 3-letter string.
POW26 = [1, 26, 26 * 26]


def encode(s, start):
    """Encode the three characters ``s[start:start+3]`` as a base-26 integer.

    Assumes lowercase ASCII letters ('a' maps to 0), giving a value in
    ``[0, 26**3)``. ``s`` may be a ``str`` or a sequence of byte values
    (``bytes``, ``bytearray`` or a ``list`` of ints).

    Raises:
        TypeError: if ``s`` is of any other type.
    """
    if isinstance(s, str):
        a, b, c = ord(s[start]), ord(s[start + 1]), ord(s[start + 2])
    elif isinstance(s, (list, bytes, bytearray)):
        a, b, c = s[start], s[start + 1], s[start + 2]
    else:
        raise TypeError("Unsupported type for encode")
    return (a - 97) * POW26[2] + (b - 97) * POW26[1] + (c - 97)


def main():
    """Read N, Q, S and Q queries from stdin; write type-2 answers to stdout.

    Input is consumed as whitespace-separated tokens, so it does not matter
    how the header values are split across lines.
    Query ``1 k x`` sets the 1-indexed character S_k to ``x``.
    Query ``2 l r a`` prints the sum of ``i - l + 1`` over every 1-indexed
    start position ``i`` with ``l <= i <= r - 2`` where ``S[i..i+2] == a``.
    This is the quantity the original submission computed.
    NOTE(review): confirm it against the problem statement.
    """
    tokens = sys.stdin.buffer.read().split()
    N = int(tokens[0])
    Q = int(tokens[1])
    pos = 2
    if N > 0:
        state = bytearray(tokens[pos])  # mutable copy of S as byte values
        pos += 1
    else:
        state = bytearray()

    # indices[code] holds the 1-indexed start positions whose 3-letter
    # substring encodes to ``code``; each Set is created on first use.
    indices = [None] * (26 * 26 * 26)
    for i in range(N - 2):
        code = encode(state, i)
        if indices[code] is None:
            indices[code] = Set()
        indices[code].add(i + 1)

    out_lines = []
    for _ in range(Q):
        q = int(tokens[pos])
        if q == 1:
            k = int(tokens[pos + 1]) - 1  # 0-indexed position to overwrite
            x = tokens[pos + 2]
            pos += 3
            # Only triples starting in [k-2, k] contain position k; clamp
            # to the valid start range [0, N-3].
            start = max(0, k - 2)
            end = min(N - 3, k)
            for j in range(start, end + 1):
                c = encode(state, j)
                if indices[c] is not None:
                    indices[c].remove(j + 1)
            state[k] = x[0]  # indexing bytes yields the byte value
            for j in range(start, end + 1):
                c = encode(state, j)
                if indices[c] is None:
                    indices[c] = Set()
                indices[c].add(j + 1)
        elif q == 2:
            l_pos = int(tokens[pos + 1])
            r_pos = int(tokens[pos + 2])
            a = tokens[pos + 3]
            pos += 4
            tree = indices[encode(a, 0)]
            ans = 0
            if tree is not None:
                # Ranks [lo, hi) are the starts i with l <= i <= r - 2,
                # i.e. triples lying entirely inside [l, r].
                hi = tree.lower_bound(r_pos - 1)
                lo = tree.lower_bound(l_pos)
                if hi > lo:
                    ans = (tree.prefix_sum(hi) - tree.prefix_sum(lo)
                           - (hi - lo) * (l_pos - 1))
            out_lines.append(f"{ans}\n".encode())
        else:
            # Token-based parsing cannot skip an unknown query safely.
            raise ValueError(f"unknown query type {q}")
    sys.stdout.buffer.write(b"".join(out_lines))


if __name__ == "__main__":
    main()